## Tutorial: Interventions on a Non-Transformer Model (GRUs)

In [1]:
# Notebook provenance: author and last-revision date.
__author__, __version__ = "Zhengxuan Wu", "12/22/2023"

### Overview

This tutorial shows how to use this library with recurrent neural networks such as GRUs. The setup is essentially the same as for standard transformer-based models.

### Set-up

In [2]:
try:
    # This library is our indicator that the required installs
    # need to be done.
    import transformers
    import sys

    sys.path.append("align-transformers/")
except ModuleNotFoundError:
    !git clone https://github.com/frankaging/align-transformers.git
    !pip install -r align-transformers/requirements.txt
    import sys

    sys.path.append("align-transformers/")

In [3]:
import sys

sys.path.append("../..")

import torch
import pandas as pd
from models.basic_utils import embed_to_distrib, top_vals, format_token
from models.configuration_intervenable_model import (
    IntervenableRepresentationConfig,
    IntervenableConfig,
)
from models.intervenable_base import IntervenableModel
from models.interventions import (
    VanillaIntervention,
    RotatedSpaceIntervention,
    LowRankRotatedSpaceIntervention,
)
from models.gru.modelings_gru import GRUConfig
from models.gru.modelings_intervenable_gru import create_gru_classifier

%config InlineBackend.figure_formats = ['svg']
from plotnine import (
    ggplot,
    geom_tile,
    aes,
    facet_wrap,
    theme,
    element_text,
    geom_bar,
    geom_hline,
    scale_y_log10,
)

config, tokenizer, gru = create_gru_classifier(GRUConfig(n_layer=1, h_dim=2))

[2024-01-10 15:13:59,567] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)
loaded model


### Vanilla intervention on multiple time steps
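Recurrent neural networks such as GRUs carry a hidden state from one time step to the next, so an intervention on one state should ripple through all later states. Conversely, an intervention on a later state can block the effects of interventions on earlier states when it overwrites an information bottleneck, i.e. the only path through which earlier steps influence later ones. The two cells below check this: intervening at several time steps gives exactly the same output as intervening only at the last of those steps.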

In [4]:
# A single VanillaIntervention on the GRU's "cell_output", addressed by time
# step ("t").
# NOTE(review): the positional arguments are presumably
# (layer, component, unit, max number of units) — confirm against
# IntervenableRepresentationConfig.
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gru),
    intervenable_representations=[
        IntervenableRepresentationConfig(0, "cell_output", "t", 1)
    ],
    intervenable_interventions_type=VanillaIntervention,
)
intervenable = IntervenableModel(intervenable_config, gru)

# Random base/source inputs of shape (10, 10, 2). The trailing 2 presumably has
# to agree with h_dim=2 of the GRU above. Base is drawn before source.
base = dict(inputs_embeds=torch.rand(10, 10, 2))
source = dict(inputs_embeds=torch.rand(10, 10, 2))

In [5]:
# Intervening at base steps (0, 5, 7) should give the same output as intervening
# at step 7 alone: the state written at step 7 overwrites whatever the earlier
# interventions changed.
# NOTE(review): unit locations are presumably (source positions, base positions),
# with one inner list per example in the batch of 10 — confirm. The original
# note says the hook is called 10 times even when a position is intervened on
# once, presumably because the GRU cell runs once per time step.
_, counterfactual_outputs_all = intervenable(
    base,
    [source],
    {"sources->base": ([[[0, 2, 4]] * 10], [[[0, 5, 7]] * 10])},
)

_, counterfactual_outputs_last = intervenable(
    base,
    [source],
    {"sources->base": ([[[4]] * 10], [[[7]] * 10])},
)

same_output = torch.equal(counterfactual_outputs_all[0], counterfactual_outputs_last[0])
# Turn the tutorial's claim into a check instead of only printing it.
assert same_output, "the intervention at step 7 should mask the interventions at steps 0 and 5"
same_output

True


In [6]:
# Same check with the later intervention at base step 5: intervening at steps
# (0, 5) should give the same output as intervening at step 5 alone.
# NOTE(review): unit locations are presumably (source positions, base positions),
# with one inner list per example in the batch of 10 — confirm.
_, counterfactual_outputs_all = intervenable(
    base,
    [source],
    {"sources->base": ([[[0, 2]] * 10], [[[0, 5]] * 10])},
)

_, counterfactual_outputs_last = intervenable(
    base,
    [source],
    {"sources->base": ([[[2]] * 10], [[[5]] * 10])},
)

same_output = torch.equal(counterfactual_outputs_all[0], counterfactual_outputs_last[0])
# Turn the tutorial's claim into a check instead of only printing it.
assert same_output, "the intervention at step 5 should mask the intervention at step 0"
same_output

True


### Subspace DAS (Distributed Alignment Search): intervening on a single time step

In [7]:
# A low-rank rotated-space intervention on the GRU's "cell_output" (time-step
# units). The rank (2) equals h_dim=2 of the GRU built above, so the rotation
# spans the whole hidden state.
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gru),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            0, "cell_output", "t", 1, intervenable_low_rank_dimension=2
        )
    ],
    intervenable_interventions_type=LowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, gru)

# One example with one time step each; base is drawn before source.
base = dict(inputs_embeds=torch.rand(1, 1, 2))
source = dict(inputs_embeds=torch.rand(1, 1, 2))

# Reference outputs without any intervention.
print("base", intervenable(base)[0][0])
print("source", intervenable(source)[0][0])

base tensor([[-0.1404, -0.0601]])
source tensor([[0.0034, 0.0207]])


In [8]:
# Intervene at step 0 of the single-step sequences. With a rank-2 rotation over
# the 2-dim hidden state, the whole state is replaced, so the counterfactual
# output should reproduce the source output.
_, counterfactual_outputs = intervenable(
    base, [source], {"sources->base": ([[[0]]], [[[0]]])}
)
counterfactual_output = counterfactual_outputs[0]

# Check the claim rather than relying on eyeballing the printed values
# (assert_close tolerates float round-off from the rotation).
torch.testing.assert_close(counterfactual_output, intervenable(source)[0][0])
print(counterfactual_output)

# Dummy loss: only confirms that gradients flow back through the intervention.
counterfactual_output.sum().backward()

tensor([[0.0034, 0.0207]], grad_fn=<MmBackward0>)


In [9]:
# Two identical low-rank rotated-space interventions on the same GRU component.
# The rank-2 space is split into two subspace partitions, and both
# interventions share intervention_link_key=0.
# NOTE(review): the original note says linked interventions target the same
# subspace — presumably they share one rotation; confirm.
# The comprehension builds two separate config objects, each with its own
# partition list.
intervenable_config = IntervenableConfig(
    intervenable_model_type=type(gru),
    intervenable_representations=[
        IntervenableRepresentationConfig(
            0,
            "cell_output",
            "t",
            1,
            intervenable_low_rank_dimension=2,
            subspace_partition=[[0, 1], [1, 2]],
            intervention_link_key=0,
        )
        for _ in range(2)
    ],
    intervenable_interventions_type=LowRankRotatedSpaceIntervention,
)
intervenable = IntervenableModel(intervenable_config, gru)

# Fresh single-example, single-step inputs; base is drawn before source.
base = dict(inputs_embeds=torch.rand(1, 1, 2))
source = dict(inputs_embeds=torch.rand(1, 1, 2))

# Reference outputs without any intervention.
print("base", intervenable(base)[0][0])
print("source", intervenable(source)[0][0])

base tensor([[-0.1157, -0.0500]])
source tensor([[0.0614, 0.0535]])


In [10]:
# Two linked interventions at step 0, each using one subspace partition
# (subspaces 0 and 1). Together they presumably cover the full rank-2 rotated
# space, so the counterfactual output should reproduce the source output.
_, counterfactual_outputs = intervenable(
    base,
    [source, source],
    {"sources->base": ([[[0]], [[0]]], [[[0]], [[0]]])},
    subspaces=[[[0]], [[1]]],
)
counterfactual_output = counterfactual_outputs[0]

# Check the claim rather than relying on eyeballing the printed values
# (assert_close tolerates float round-off from the rotation).
torch.testing.assert_close(counterfactual_output, intervenable(source)[0][0])
print(counterfactual_output)

# Dummy loss: only confirms that gradients flow back through the interventions.
counterfactual_output.sum().backward()

tensor([[0.0614, 0.0535]], grad_fn=<MmBackward0>)
