## Tutorial: Interchange Intervention Training

```python
__author__ = "Zhengxuan Wu"
__version__ = "01/11/2024"
```

### Overview

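[Interchange Intervention Training](https://arxiv.org/abs/2112.00826) (IIT) is a technique for training neural networks to be interpretable in a data-driven fashion. As its name suggests, it uses interventions as the training signal: activations computed on a source input are swapped into the run on a base input, and the network is trained so that the resulting counterfactual output matches what a high-level causal model predicts for the same swap. As a result, the network's activations are interpretable in the sense that we can intervene on them at inference time and obtain predictable, interpretable counterfactual behaviors.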

This library supports IIT because IIT is essentially a vanilla intervention with gradients enabled for all of the model's parameters.
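To make the idea concrete, here is a minimal NumPy sketch of one IIT step on a hypothetical toy model (a two-layer linear network, not pyvene's API): run the source input, copy selected hidden units into the base run (a vanilla interchange intervention), compare the counterfactual output with a counterfactual label, and update *all* weights by gradient descent. The model, the squared-error loss, the swapped unit, and the label value are illustrative assumptions; gradients are derived by hand so the example needs only NumPy.

```python
import numpy as np

# Hypothetical toy model (not pyvene): hidden h = W1 @ x, output y = w2 @ h.
W1 = np.array([[0.5, -0.3],
               [0.8, 0.1]])
w2 = np.array([1.0, -0.5])


def forward(x, W1, w2, swap_dims=(), h_source=None):
    """Run the toy model; if h_source is given, overwrite swap_dims of h with it."""
    h = W1 @ x  # fresh array, safe to modify in place
    if h_source is not None:
        h[list(swap_dims)] = h_source[list(swap_dims)]  # vanilla interchange
    return h, float(w2 @ h)


def iit_loss_and_grads(x_base, x_source, y_cf, W1, w2, swap_dims):
    """Squared-error IIT loss on the counterfactual output, with hand-derived gradients."""
    h_source, _ = forward(x_source, W1, w2)                     # source run
    h_cf, y_hat = forward(x_base, W1, w2, swap_dims, h_source)  # intervened base run
    err = y_hat - y_cf
    loss = 0.5 * err ** 2
    grad_w2 = err * h_cf
    # Unit j was computed from x_source if swapped, else from x_base,
    # so row j of W1 gets its gradient through whichever run produced it.
    grad_W1 = np.zeros_like(W1)
    for j in range(W1.shape[0]):
        x_j = x_source if j in swap_dims else x_base
        grad_W1[j] = err * w2[j] * x_j
    return loss, grad_W1, grad_w2


# One IIT step with a hypothetical counterfactual label y_cf, i.e. the output a
# high-level causal model would predict after the same swap.
x_base, x_source = np.array([1.0, 0.0]), np.array([0.0, 1.0])
y_cf, swap, lr = 3.0, (0,), 0.1
loss0, gW1, gw2 = iit_loss_and_grads(x_base, x_source, y_cf, W1, w2, swap)
W1_new, w2_new = W1 - lr * gW1, w2 - lr * gw2  # update all weights, not just one site
loss1, _, _ = iit_loss_and_grads(x_base, x_source, y_cf, W1_new, w2_new, swap)
print(f"IIT loss before: {loss0:.3f}, after one step: {loss1:.3f}")
```

In this toy, the swapped hidden unit is computed from the source input, so its row of `W1` is trained through the source run, while the remaining row is trained through the base run.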

### Set-up

```python
# try:
#     # This library is our indicator that the required installs
#     # need to be done.
#     import pyvene

# except ModuleNotFoundError:
#     !pip install git+https://github.com/frankaging/pyvene.git
```

```python
# import pandas as pd
import sys
sys.path.append("../..")

from pyvene.models.basic_utils import (
    embed_to_distrib,
    top_vals,
    format_token,
    count_parameters
)

from pyvene.models.gpt2.modelings_intervenable_gpt2 import create_gpt2

from pyvene.models.intervenable_base import IntervenableModel
from pyvene.models.interventions import VanillaIntervention
from pyvene.models.interventions import RotatedSpaceIntervention

# VanillaIntervention is already imported above.
from pyvene.models.configuration_intervenable_model import (
    IntervenableConfig, IntervenableRepresentationConfig
)

# Load GPT-2 together with its config and tokenizer.
config, tokenizer, gpt = create_gpt2()
```

```
[2024-01-12 02:39:07,117] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
loaded model
```


```python
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gpt),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            2,                 # layer
            "mlp_activation",  # component: the MLP activation
            "pos",             # unit: token positions
            1,                 # max number of units (positions) to intervene on
        ),
    ],
    intervenable_interventions_type=VanillaIntervention,
)
intervenable = IntervenableModel(intervenable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt"),
]
```

```python
intervenable.count_parameters()
```

```
0
```

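The count is 0 because the vanilla intervention has no trainable parameters of its own and the model's parameters are not yet counted as trainable. To train with IIT, we only need to turn on gradients for all of the model's parameters: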

```python
intervenable.enable_model_gradients()
intervenable.count_parameters()
```

```
124439808
```
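The 124,439,808 parameters reported after `enable_model_gradients()` match the size of the GPT-2 small base model without a separate language-modeling head. As a sanity check, here is a short arithmetic sketch, assuming `create_gpt2` loads the standard GPT-2 small configuration (50,257-token vocabulary, 1,024 positions, 12 layers, hidden size 768, MLP size 3,072); these hyperparameters are assumptions, not values printed by the notebook.

```python
# Assumed GPT-2 small hyperparameters.
vocab_size = 50257
n_positions = 1024
d_model = 768
n_layer = 12
d_mlp = 4 * d_model  # 3072


def linear(n_in, n_out):
    """Parameters of a dense layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


def layer_norm(d):
    """Parameters of a LayerNorm: scale and shift vectors."""
    return 2 * d


embeddings = vocab_size * d_model + n_positions * d_model  # token + position embeddings
block = (
    layer_norm(d_model)              # ln_1
    + linear(d_model, 3 * d_model)   # attn.c_attn (fused Q, K, V)
    + linear(d_model, d_model)       # attn.c_proj
    + layer_norm(d_model)            # ln_2
    + linear(d_model, d_mlp)         # mlp.c_fc
    + linear(d_mlp, d_model)         # mlp.c_proj
)
total = embeddings + n_layer * block + layer_norm(d_model)  # plus the final ln_f
print(f"{total:,}")  # 124,439,808
```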

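```python
# Swap the layer-2 MLP activation at position 3 (the country token) from the
# source run into the base run.
base_outputs, counterfactual_outputs = intervenable(
    base, sources, {"sources->base": ([[[3]]], [[[3]]])}
)
```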
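Position 3 is the country token, assuming GPT-2 splits each prompt into the five tokens `The`, ` capital`, ` of`, ` Spain`/` Italy`, ` is` (the five-row output below is consistent with this). The following minimal NumPy sketch shows what a vanilla intervention does with the location spec. It assumes the pair is ordered (source locations, base locations), as the key `"sources->base"` suggests, and that each entry is nested as [intervention][batch example][position]; both are assumptions here, and since both entries are `[[[3]]]` in this call, the order does not matter. `vanilla_interchange` is a hypothetical helper, not a pyvene function.

```python
import numpy as np


def vanilla_interchange(base_acts, source_acts, base_positions, source_positions):
    """Copy source activations at source_positions into base at base_positions.

    base_acts, source_acts: arrays of shape (batch, seq_len, hidden).
    *_positions: nested lists [batch][num_units] for a single intervention.
    """
    out = base_acts.copy()  # leave the original base activations untouched
    for b, (bpos, spos) in enumerate(zip(base_positions, source_positions)):
        out[b, bpos] = source_acts[b, spos]
    return out


# Hypothetical activations: 1 example, 5 tokens, hidden size 4.
base_acts = np.zeros((1, 5, 4))
source_acts = np.ones((1, 5, 4))
source_locs, base_locs = [[[3]]], [[[3]]]  # mirrors {"sources->base": (...)}
mixed = vanilla_interchange(base_acts, source_acts, base_locs[0], source_locs[0])
print(mixed[0, :, 0])  # → [0. 0. 0. 1. 0.]
```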

```python
counterfactual_outputs.last_hidden_state - base_outputs.last_hidden_state
```

```
tensor([[[ 0.1292, -0.0520,  0.1511,  ..., -0.1309,  0.0113,  0.0342],
         [ 0.0603, -0.7758,  0.1832,  ...,  0.2912,  0.2868,  0.2893],
         [-0.5429, -0.3998,  0.0891,  ..., -0.3754,  0.1311,  0.2489],
         [ 0.2532,  0.1299,  0.0409,  ..., -0.2040, -0.1513,  0.3049],
         [ 0.1114,  0.1318,  0.4405,  ...,  0.1814,  0.2783,  0.0206]]],
       grad_fn=<SubBackward0>)
```

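```python
# .sum() is only a placeholder scalar loss; real IIT training would use a task
# loss on the counterfactual outputs.
counterfactual_outputs.last_hidden_state.sum().backward()
```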
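Calling `.sum()` on the hidden states only gives `backward()` a scalar to differentiate. In IIT, the loss is instead computed on the counterfactual outputs against the label that the high-level causal model predicts for the same swap. With a language-modeling head, a natural choice would be cross-entropy on the next-token logits (for this Spain/Italy pair, one plausible counterfactual label is Italy's capital). The model here returns hidden states rather than logits, so the following is only a NumPy sketch of that loss and its gradient on hypothetical logits; the vocabulary size, logit values, and label index are all made up for illustration.

```python
import numpy as np


def cross_entropy_and_grad(logits, label):
    """Softmax cross-entropy for one example and its gradient w.r.t. the logits."""
    shifted = logits - logits.max()  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.exp(shifted).sum()
    loss = -np.log(probs[label])
    grad = probs.copy()
    grad[label] -= 1.0  # d loss / d logits = softmax - one_hot(label)
    return float(loss), grad


# Hypothetical counterfactual logits over a 4-token vocabulary; index 2 stands in
# for the token the high-level model predicts after the swap.
cf_logits = np.array([1.0, 0.5, -0.2, 0.3])
loss, grad = cross_entropy_and_grad(cf_logits, label=2)
print(f"loss={loss:.4f}", grad)
```

The gradient `softmax - one_hot` is what would flow back into the network in place of the all-ones gradient produced by `.sum()`.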

Check that gradients have reached the model's parameters, for example the weights of the first block's MLP:

```python
gpt.h[0].mlp.c_fc.weight.grad
```

```
tensor([[-2.7394e-01, -9.8538e-03,  2.1004e-02,  ..., -1.9908e-02,
         -3.4756e-02, -8.5781e-02],
        [-1.3462e-01, -1.8148e-03, -2.9549e-02,  ..., -5.2381e-02,
         -1.0547e-01,  1.9281e-01],
        [ 1.4480e-01,  9.1471e-04, -1.4906e-02,  ...,  2.0330e-02,
          3.9505e-02, -6.6796e-02],
        ...,
        [ 2.4939e-01, -2.0916e-03, -3.0832e-03,  ...,  2.0648e-02,
         -1.5234e-02, -1.5796e-04],
        [-5.5724e-02, -9.8790e-03,  5.5369e-02,  ...,  1.8155e-02,
          2.2969e-02, -4.6784e-02],
        [ 1.6450e-01,  1.9703e-02, -2.0497e-02,  ...,  2.5583e-02,
         -1.5143e-02, -2.4821e-01]])
```