## Tutorial on Interventions on Non-transformer Models: MLPs

In [1]:
# Notebook metadata: author name and a version stamp.
__author__ = "Zhengxuan Wu"
__version__ = "12/20/2023"

### Overview

This tutorial shows how to use this library on non-transformer models, such as MLPs. The set-up is pretty much the same as for standard transformer-based models.

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys

    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys

    sys.path.append("align-transformers/")

In [3]:
import sys

sys.path.append("../..")

import torch
import pandas as pd
from models.basic_utils import embed_to_distrib, top_vals, format_token
from models.configuration_intervenable_model import (
    IntervenableRepresentationConfig,
    IntervenableConfig,
)
from models.intervenable_base import IntervenableModel
from models.interventions import (
    VanillaIntervention,
    RotatedSpaceIntervention,
    LowRankRotatedSpaceIntervention,
)
from models.mlp.modelings_mlp import MLPConfig
from models.mlp.modelings_intervenable_mlp import create_mlp_classifier

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

# Seed torch's global RNG before anything random happens so that a fresh
# "Restart & Run All" reproduces the same numbers: the `torch.rand` inputs drawn
# in later cells, and any random initialisation done inside
# `create_mlp_classifier` (presumably the model weights -- not visible here).
torch.manual_seed(0)

# Build the MLP classifier from a config with hidden size 32 and n_layer=1.
# The call returns a (config, tokenizer, model) triple.
config, tokenizer, mlp = create_mlp_classifier(MLPConfig(h_dim=32, n_layer=1))

[2024-01-10 15:12:21,053] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
loaded model


### Intervene in middle layer by partitioning representations into subspaces

An MLP may contain only a single "token" representation at each layer. As a result, we often want to intervene on a subspace of this "token" representation to localize a concept.

In [4]:
# Split the 32-dim representation into two index ranges, [0, 16) and [16, 32),
# which are addressed later as subspace 0 and subspace 1.
hidden_partition = [[0, 16], [16, 32]]

# A single intervention site on the "block_output" component.
# NOTE(review): the positional 0 and 1 are presumably the layer index and the
# number of intervened units -- confirm against IntervenableRepresentationConfig.
block_output_site = IntervenableRepresentationConfig(
    0,
    "block_output",
    "pos",  # an mlp layer creates a single token representation
    1,
    subspace_partition=hidden_partition,
)

intervenable_config = IntervenableConfig(
    intervenable_model_type=type(mlp),
    intervenable_representations=[block_output_site],
    intervenable_interventions_type=RotatedSpaceIntervention,
)
# Wrap the plain MLP so that it can be called with interventions.
intervenable = IntervenableModel(intervenable_config, mlp)

# Random inputs: one example, one "token", 32 features each.
base = {"inputs_embeds": torch.rand(1, 1, 32)}
source = {"inputs_embeds": torch.rand(1, 1, 32)}

# Calling the wrapper with only an input dict (no sources) gives the reference
# outputs that the intervened run is compared against below.
for label, example in (("base", base), ("source", source)):
    print(label, intervenable(example))

base ((tensor([[ 0.2097, -0.0147]]),), None)
source ((tensor([[ 0.1747, -0.1433]]),), None)


In [7]:
# Unit locations for the intervention, given separately for the source side
# and the base side; both point at index 0 (the only "token" of the MLP).
source_positions = [[[0]]]
base_positions = [[[0]]]

# Select both partitions (1 and 0) for the single intervention.
selected_subspaces = [[[1, 0]]]

_, counterfactual_outputs = intervenable(
    base,
    [source],
    {"sources->base": (source_positions, base_positions)},
    subspaces=selected_subspaces,
)

In [8]:
# The intervention call above selected both subspaces (1 and 0), i.e. the whole
# partitioned representation, so this should be the same as the source output.
counterfactual_outputs

(tensor([[ 0.1747, -0.1433]], grad_fn=<SqueezeBackward1>),)

### Intervene on the subspace with multiple sources

In [9]:
# Two identically configured intervention sites on the same "block_output"
# component, one per source. Each iteration builds a fresh config object with
# its own partition lists.
linked_sites = [
    IntervenableRepresentationConfig(
        0,
        "block_output",
        "pos",  # an mlp layer creates a single token representation
        1,
        intervenable_low_rank_dimension=32,
        # partition into two sets of subspaces
        subspace_partition=[[0, 16], [16, 32]],
        # sites sharing a link key are linked: they target the same subspace
        intervention_link_key=0,
    )
    for _ in range(2)
]

intervenable_config = IntervenableConfig(
    intervenable_model_type=type(mlp),
    intervenable_representations=linked_sites,
    intervenable_interventions_type=LowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, mlp)

# Random inputs: a batch of 10 examples, one "token", 32 features each.
base = {"inputs_embeds": torch.rand(10, 1, 32)}
source = {"inputs_embeds": torch.rand(10, 1, 32)}

# Reference outputs without any intervention, for comparison with the
# counterfactual run below.
for label, batch in (("base", base), ("source", source)):
    print(label, intervenable(batch))

base ((tensor([[ 0.1992, -0.0437],
        [ 0.1270,  0.0192],
        [ 0.2138, -0.1187],
        [ 0.2332, -0.1028],
        [ 0.1555, -0.1350],
        [ 0.1667, -0.0826],
        [ 0.1946, -0.1565],
        [ 0.1763, -0.1517],
        [ 0.2392,  0.0309],
        [ 0.1352, -0.1232]]),), None)
source ((tensor([[ 0.2338, -0.1565],
        [ 0.1794, -0.1778],
        [ 0.1446, -0.0479],
        [ 0.2647, -0.1311],
        [ 0.1848, -0.0395],
        [ 0.2853, -0.1353],
        [ 0.2116, -0.0744],
        [ 0.1642, -0.1196],
        [ 0.2374, -0.1412],
        [ 0.2045, -0.1211]]),), None)


In [11]:
n_examples = 10  # matches the batch dimension of `base` and `source`

# Per-intervention, per-example unit locations: index 0 for every example,
# listed once for each of the two interventions, on both the source and the
# base side.
source_positions = [[[0]] * n_examples, [[0]] * n_examples]
base_positions = [[[0]] * n_examples, [[0]] * n_examples]

# The first intervention selects subspace 1 and the second selects subspace 0.
selected_subspaces = [[[1]] * n_examples, [[0]] * n_examples]

_, counterfactual_outputs = intervenable(
    base,
    [source, source],
    {"sources->base": (source_positions, base_positions)},
    subspaces=selected_subspaces,
)
print(counterfactual_outputs)  # this should be the same as the source output

# Fake loss: reduce to a scalar and backpropagate, purely to make sure that
# gradients can be populated through the intervention.
fake_loss = counterfactual_outputs[0].sum()
fake_loss.backward()

(tensor([[ 0.2338, -0.1565],
        [ 0.1794, -0.1778],
        [ 0.1446, -0.0479],
        [ 0.2647, -0.1311],
        [ 0.1848, -0.0395],
        [ 0.2853, -0.1353],
        [ 0.2116, -0.0744],
        [ 0.1642, -0.1196],
        [ 0.2374, -0.1412],
        [ 0.2045, -0.1211]], grad_fn=<SqueezeBackward1>),)
