## Tutorial: More Complex Intervention Usage

In [1]:
__author__ = "Zhengxuan Wu"
__version__ = "12/19/2023"

### Overview

The basic tutorials cover simple uses of interventions. Here, we showcase some more advanced usages of this library, which supports flexible interventions such as grouping interventions together and skipping interventions when needed. This is a live tutorial that collects a set of advanced usages in one place.

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys

    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys

    sys.path.append("align-transformers/")

In [3]:
import sys

sys.path.append("../..")

import torch
import pandas as pd
from models.basic_utils import embed_to_distrib, top_vals, format_token
from models.configuration_intervenable_model import (
    IntervenableRepresentationConfig,
    IntervenableConfig,
)
from models.intervenable_base import IntervenableModel
from models.interventions import VanillaIntervention, LowRankRotatedSpaceIntervention
from models.gpt2.modelings_intervenable_gpt2 import create_gpt2

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

config, tokenizer, gpt = create_gpt2(cache_dir="../../../.huggingface_cache")

[2024-01-10 15:09:51,716] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
loaded model


### Non-group-based Interventions vs. Group-based Interventions

Without grouping, two identical sources are used to intervene at two locations (the block outputs of layers 0 and 2), one source per intervention.

In [4]:
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gpt),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
        ),
        IntervenableRepresentationConfig(
            2,
            "block_output",
            "pos",
            1,
        ),
    ],
    intervenable_interventions_type=VanillaIntervention,
)
intervenable = IntervenableModel(intervenable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [
    tokenizer("The capital of Italy is", return_tensors="pt"),
    tokenizer("The capital of Italy is", return_tensors="pt"),
]

In [5]:
_, counterfactual_outputs_no_group = intervenable(
    base, sources, {"sources->base": ([[[3]], [[4]]], [[[3]], [[4]]])}
)

With grouping, a single source is used for all interventions in the group (here, both interventions are configured with group_key=0).

In [6]:
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gpt),
    intervenable_representations=[
        IntervenableRepresentationConfig(0, "block_output", "pos", 1, group_key=0),
        IntervenableRepresentationConfig(2, "block_output", "pos", 1, group_key=0),
    ],
    intervenable_interventions_type=VanillaIntervention,
)
intervenable = IntervenableModel(intervenable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
sources = [tokenizer("The capital of Italy is", return_tensors="pt")]

In [7]:
_, counterfactual_outputs_group = intervenable(
    base, sources, {"sources->base": ([[[3]], [[4]]], [[[3]], [[4]]])}
)

In [8]:
torch.equal(
    counterfactual_outputs_no_group.last_hidden_state,
    counterfactual_outputs_group.last_hidden_state,
)

True
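Why do the two setups agree? Below is a minimal sketch on a toy model, assuming a vanilla intervention simply overwrites base activations at the chosen positions with the source's activations from the same block; it is not this library's implementation, and the names `blocks`, `run`, and `patches` are hypothetical. Without grouping, each intervention gets its own source run; with grouping, one source run feeds every intervention in the group. When the sources are identical, the patched outputs match exactly.

```python
import torch

torch.manual_seed(0)
hidden = 8

# Toy stand-in for a transformer: three position-wise "blocks" (hypothetical).
blocks = [torch.nn.Linear(hidden, hidden) for _ in range(3)]


@torch.no_grad()
def run(x, patches=None):
    """Run the toy model and return (final output, per-block outputs).

    `patches` maps block index -> (positions, source_block_output); at that
    block, the listed positions are overwritten with the source's values.
    """
    patches = patches or {}
    acts = []
    h = x
    for i, block in enumerate(blocks):
        h = torch.tanh(block(h))
        if i in patches:
            positions, src = patches[i]
            h = h.clone()
            h[:, positions] = src[:, positions]
        acts.append(h)
    return h, acts


base = torch.randn(1, 5, hidden)  # (batch, positions, hidden)
source = torch.randn(1, 5, hidden)

# No grouping: one source run per intervention (block 0 @ pos 3, block 2 @ pos 4).
_, acts_a = run(source)
_, acts_b = run(source)
out_no_group, _ = run(base, {0: ([3], acts_a[0]), 2: ([4], acts_b[2])})

# Grouping: a single source run feeds both interventions in the group.
_, acts = run(source)
out_group, _ = run(base, {0: ([3], acts[0]), 2: ([4], acts[2])})

out_base, _ = run(base)
print(torch.equal(out_no_group, out_group))
```

Grouping saves forward passes over the source; it only changes results when the per-intervention sources would have differed.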

### Smart skipping of interventions by passing in None

This library treats the configured intervention list as the source of truth when accepting inputs. Sometimes, however, we only need to apply a subset of the listed interventions. We can do that by passing None in place of the source (and its unit locations) for each intervention we want to skip.
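The skipping rule can be sketched in plain PyTorch, assuming one list entry per configured intervention and a vanilla swap at each active entry. The helper `apply_interventions` is hypothetical, not part of this library; the toy setup mirrors the cells below, where three identical interventions make it irrelevant which one is kept.

```python
import torch


def apply_interventions(base_hidden, sources, locations):
    """Apply a vanilla swap for every intervention whose source is not None.

    sources[i] and locations[i] line up with the i-th configured intervention;
    passing None for both skips that intervention (hypothetical helper).
    """
    if len(sources) != len(locations):
        raise ValueError("expected one entry per configured intervention")
    out = base_hidden.clone()
    for src, pos in zip(sources, locations):
        if src is None:
            if pos is not None:
                raise ValueError("a skipped intervention needs a None location")
            continue
        out[:, pos] = src[:, pos]
    return out


torch.manual_seed(0)
base = torch.randn(1, 5, 8)
source = torch.randn(1, 5, 8)

# Three identical interventions at position 4; keep only one of them each time.
out_1 = apply_interventions(base, [None, None, source], [None, None, [4]])
out_2 = apply_interventions(base, [None, source, None], [None, [4], None])
out_3 = apply_interventions(base, [source, None, None], [[4], None, None])
print(torch.equal(out_1, out_2), torch.equal(out_2, out_3))
```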

In [9]:
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gpt),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
        ),
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
        ),
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
        ),
    ],
    intervenable_interventions_type=VanillaIntervention,
)
intervenable = IntervenableModel(intervenable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

In [10]:
_, counterfactual_outputs_1 = intervenable(
    base,
    [None, None, source],
    {"sources->base": ([None, None, [[4]]], [None, None, [[4]]])},
)
_, counterfactual_outputs_2 = intervenable(
    base,
    [None, source, None],
    {"sources->base": ([None, [[4]], None], [None, [[4]], None])},
)
_, counterfactual_outputs_3 = intervenable(
    base,
    [source, None, None],
    {"sources->base": ([[[4]], None, None], [[[4]], None, None])},
)

In [11]:
print(
    torch.equal(
        counterfactual_outputs_1.last_hidden_state,
        counterfactual_outputs_2.last_hidden_state,
    ),
    torch.equal(
        counterfactual_outputs_2.last_hidden_state,
        counterfactual_outputs_3.last_hidden_state,
    ),
)

True True


### Weight-sharing interventions targeting different subspaces

Trainable interventions also support weight sharing. This is useful when two interventions target different subspaces of the same learned (rotated) basis. This differs from a single intervention with partitioned subspaces: the latter only allows intervening on one subspace at a time, which can be useful as well. However, weight sharing combined with smart skipping may suffice for all of these use cases.
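A minimal sketch of the idea, assuming a low-rank rotated-space swap of the form base + ((source - base) @ R_s) @ R_s.T, where R_s holds the orthonormal basis vectors of one subspace. The dictionary `shared_weights`, the function `low_rank_swap`, and the toy sizes are hypothetical illustrations, not this library's internals. Two interventions that resolve the same link key receive the same parameter object, so swapping subspace 1 through either of them is the same computation, and training both updates one tensor.

```python
import torch

torch.manual_seed(0)
d, r = 8, 2  # toy hidden size and low-rank dimension

# Interventions that resolve the same link key get the same Parameter object,
# so they use, and train, one shared rotation (its columns are orthonormal).
shared_weights = {0: torch.nn.Parameter(torch.linalg.qr(torch.randn(d, r))[0])}
partition = [[0, 1], [1, 2]]  # subspace i = rotated dims partition[i][0]:partition[i][1]


def low_rank_swap(base, source, R, subspace):
    """Replace base's component in one rotated subspace with the source's."""
    start, end = partition[subspace]
    Rs = R[:, start:end]            # (d, k) basis vectors of the subspace
    delta = (source - base) @ Rs    # difference, in subspace coordinates
    return base + delta @ Rs.T      # map the difference back to hidden space


base = torch.randn(1, d)
source = torch.randn(1, d)

R_a = shared_weights[0]  # weight seen by intervention A
R_b = shared_weights[0]  # weight seen by intervention B (same object)

# Swapping subspace 1 through A or through B is the same computation.
out_via_a = low_rank_swap(base, source, R_a, subspace=1)
out_via_b = low_rank_swap(base, source, R_b, subspace=1)
print(torch.equal(out_via_a, out_via_b))

# A on subspace 0 plus B on subspace 1: gradients accumulate in one tensor.
out_both = low_rank_swap(low_rank_swap(base, source, R_a, 0), source, R_b, 1)
out_both.sum().backward()
print(R_a is R_b, shared_weights[0].grad.shape)
```

Because the two subspaces are orthogonal, applying them in either order is mathematically equivalent, but the floating-point results may differ in the last bits, which is why order-swapped comparisons need a tolerance.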

In [12]:
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gpt),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
            intervenable_low_rank_dimension=2,
            subspace_partition=[[0, 1], [1, 2]],
            intervention_link_key=0,  # create sym link across interventions
        ),
        IntervenableRepresentationConfig(
            0,
            "block_output",
            "pos",
            1,
            intervenable_low_rank_dimension=2,
            subspace_partition=[[0, 1], [1, 2]],
            intervention_link_key=0,  # create sym link across interventions
        ),
    ],
    intervenable_interventions_type=LowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, gpt)

base = tokenizer("The capital of Spain is", return_tensors="pt")
source = tokenizer("The capital of Italy is", return_tensors="pt")

In [13]:
_, counterfactual_outputs_1 = intervenable(
    base,
    [None, source],
    {"sources->base": ([None, [[4]]], [None, [[4]]])},
    subspaces=[None, [[1]]],
)
_, counterfactual_outputs_2 = intervenable(
    base,
    [source, None],
    {"sources->base": ([[[4]], None], [[[4]], None])},
    subspaces=[[[1]], None],
)
_, counterfactual_outputs_3 = intervenable(
    base,
    [source, source],
    {"sources->base": ([[[4]], [[4]]], [[[4]], [[4]]])},
    subspaces=[[[0]], [[1]]],
)
_, counterfactual_outputs_4 = intervenable(
    base,
    [source, source],
    {"sources->base": ([[[4]], [[4]]], [[[4]], [[4]]])},
    subspaces=[[[1]], [[0]]],
)

In [14]:
print(
    torch.equal(
        counterfactual_outputs_1.last_hidden_state,
        counterfactual_outputs_2.last_hidden_state,
    ),
    torch.equal(
        counterfactual_outputs_2.last_hidden_state,
        counterfactual_outputs_3.last_hidden_state,
    ),
    torch.allclose(
        counterfactual_outputs_1.last_hidden_state,
        counterfactual_outputs_3.last_hidden_state,
        atol=1e-5,  # expected to differ: outputs_3 also swaps subspace 0
    ),
    torch.allclose(
        counterfactual_outputs_3.last_hidden_state,
        counterfactual_outputs_4.last_hidden_state,
        atol=1e-5,  # bmm in different order will result in slightly different results
    ),
)

True False False True


In [16]:
# Gradients flow back through the shared, trainable low-rank rotation
counterfactual_outputs_4[0].sum().backward()

In [17]:
# This example shows that summation order affects floating-point precision
x = torch.randn(10, 10, 10)
s1 = x.sum()
s2 = x.sum(0).sum(0).sum(0)
print((s1 - s2).abs().max())

tensor(3.8147e-06)
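The same effect shows up with plain Python floats: floating-point addition is not associative, so regrouping a sum can change the last bits. This is why the comparisons above use torch.allclose with an atol, rather than torch.equal, when the order of operations differs. A minimal illustration using only the standard library:

```python
import math

left = (0.1 + 0.2) + 0.3   # 0.30000000000000004 + 0.3
right = 0.1 + (0.2 + 0.3)  # 0.1 + 0.5

print(left, right)                # 0.6000000000000001 0.6
print(left == right)              # False: grouping changed the rounding
print(math.isclose(left, right))  # True: equal within a relative tolerance
```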
