# Distributed Alignment Search (DAS): Searching for Linearly Encoded Concepts in Model Representations

In the last tutorial, we looked at <a href="https://arxiv.org/abs/2402.17700">RAVEL</a>, which helps us evaluate where a high-level concept might be encoded in a model's internal representations.

In particular, imagine we want to edit a model to think that Paris is in the country of Brazil, without changing whatever else the model knows about Paris (e.g., its language, continent, ...). Which representations in the model encode this fact about Paris?

In this tutorial, we'll go over **Distributed Alignment Search**, or <a href="https://arxiv.org/abs/2303.02536">DAS</a>, which helps us automatically identify a set of linear subspaces in a model's representations that encode a particular concept.

<div style="background-color:#FF9999;padding:10px 10px;border-radius:20px">
<b>Before we begin!</b>

These are good things to know before we begin the tutorial
<ul>
<li>Activation patching - check out the activation patching tutorial <a href="https://nnsight.net/notebooks/tutorials/activation_patching/">here</a>!</li>
<li>RAVEL - make sure to check out the first part of the tutorial before trying out this one!</li>
</ul>
</div>

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Things we'll talk about</b>

In case you want to tell people what you learned today!
<ul>
<li><a href="https://arxiv.org/abs/2303.02536">DAS</a> - method for finding linear subspaces of model representations that store a particular concept.</li>
</ul>

Let's do this!
</div>

In [None]:
from IPython.display import clear_output

try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight
    !git clone https://github.com/AmirZur/nnsight-tutorials.git
    %cd nnsight-tutorials/

clear_output()

## Making surgical edits - residual streams capture too much information

Last tutorial, we saw that by patching the 8th layer of the "Paris" token, we were able to change its country from France to Brazil.

![two forward runs of a model, with an arrow between the residual stream activations of Rio and Paris. After the intervention is applied, the model outputs Brazil](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/patching_visualization.png?raw=true)

In [1]:
# load model
import nnsight
from IPython.display import clear_output

model = nnsight.LanguageModel("meta-llama/Llama-3.2-1B", device_map="auto")
clear_output()

In [2]:
# does our model know where Paris is?
import torch

base_prompt = "Paris is in the country of"

with torch.no_grad():
    with model.trace(base_prompt) as tracer:
        base_tokens = tracer.invoker.inputs[0][0]['input_ids'][0]
        # Get logits from the lm_head
        base_logits = model.lm_head.output[:, -1, :].save()

base_logprobs = torch.softmax(base_logits, dim=-1)

top_completions = torch.topk(base_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 France (0.65)
 the (0.05)
 love (0.01)


In [3]:
# collect representations for a city from a different country
source_prompt = "Rio is in the country of"
source_country = model.tokenizer(" Brazil")["input_ids"][1] # includes a space

with torch.no_grad():
    with model.trace(source_prompt) as tracer:
        source_tokens = tracer.invoker.inputs[0][0]['input_ids'][0]
        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        source_hidden_states = [
            layer.output[0].save()
            for layer in model.model.layers
        ]

In [4]:
# by patching at layer 8 over Paris, we change its country from France to Brazil!
TOKEN_INDEX = 1
LAYER_INDEX = 8

with model.trace(base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 Brazil (0.61)
 the (0.05)
 Portugal (0.01)


However, we **also accidentally edit other facts about Paris**, such as its continent and language!

In [5]:
# by changing Paris's country, we also changed its continent!
TOKEN_INDEX = 1
LAYER_INDEX = 8

new_base_prompt = "Paris is in the continent of"

with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 South (0.55)
 America (0.11)
 North (0.10)


In [6]:
# as well as its language!
new_base_prompt = "Paris is a city whose main language is"

with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 Portuguese (0.57)
 Spanish (0.12)
 English (0.10)


<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Takeaway</b>

We need to find a way to make our patching **more precise**. One way to do this is to patch a unit of computation that's smaller than the whole residual stream component. There are many reasonable options, such as patching sets of neurons. In this tutorial, we'll look at how we can patch **linear subspaces** of a model's representation.
</div>

## Choosing the right unit of computation - how do models represent concepts?

What are we patching to begin with? Let's take a look at the source activations we collected.

In [7]:
source_activations = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
source_activations

tensor([[ 0.0111, -0.0206, -0.2613,  ..., -0.0281, -0.1300,  0.0346]])

In [8]:
source_activations.shape

torch.Size([1, 2048])

Can we break down the residual stream activation into smaller, meaningful units of computation?

One idea is to look at single neurons - that is, single indices within the large 2048-dimensional vector. 

Another idea, motivated by the Linear Representation Hypothesis, is that transformer-based neural networks tend to use **linear subspaces** as units of computation. Thinking about a model's activation as one giant vector, perhaps concepts are each encoded in a separate linear dimension within the vector.

![Activation represented as a linear vector, with subspaces encoding concepts such as the country & language of Paris](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/activation_vector.png?raw=true)

To patch a set of neurons, we could simply index into the ones we think encode important concepts in the model. However, enumerating all subsets of neurons is computationally infeasible.

![patching the first 3 neurons of the activations of Rio and Paris](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/patching_neurons_visualization.png?raw=true)

In [9]:
# change the list of indices to try a set of neurons to patch!
NEURON_INDICES = [0, 1, 2, 4]

base_prompt = "Paris is in the country of"

with model.trace(base_prompt) as tracer:
    # Apply the patch from the source hidden states to the base hidden states
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, NEURON_INDICES] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, NEURON_INDICES]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 France (0.64)
 the (0.05)
 love (0.01)


To patch a set of **linear subspaces**, we can follow a similar procedure, with a slight twist...

First, we **rotate** our base and source vectors. This creates two new vectors, whose neurons are linear combinations of the original vector. Next, we **patch linear subspaces** just as we would in the regular set-up. Lastly, we **rotate back** the patched vector, so that it's in the same basis as the original run.

![patch between a source and base vector, where the source & base vector are first rotated. the resulting patch is then un-rotated back to the original basis](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/das_visualization.png?raw=true)

In [10]:
# construct a rotation matrix (model_dim x model_dim)
MODEL_HIDDEN_DIM = 2048

rotator = torch.nn.Linear(MODEL_HIDDEN_DIM, MODEL_HIDDEN_DIM, bias=False)
torch.nn.init.orthogonal_(rotator.weight)

rotator = torch.nn.utils.parametrizations.orthogonal(rotator)
clear_output()

In [15]:
# play around with how many linear dimensions we patch!
N_PATCHING_DIMS = 1

base_prompt = "Paris is in the country of"

def patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False):
    grad_env = torch.enable_grad if with_grad else torch.no_grad
    with grad_env():
        with model.trace(base_prompt) as tracer:
            # rotate the base representation
            base = model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :].clone()
            rotated_base = rotator(base)

            # rotate the source representation
            source = source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]
            rotated_source = rotator(source)

            # patch the first n dimensions in the rotated space
            # (NOTE: same thing as `rotated_base[:, 0] = rotated_source[:, 0]` but we want the gradient to flow)
            rotated_patch = torch.cat([
                rotated_source[:, :N_PATCHING_DIMS],
                rotated_base[:, N_PATCHING_DIMS:]
            ], dim=1)

            # unrotate patched vector back to the original space
            patch = torch.matmul(rotated_patch, rotator.weight.T)

            # replace base with patch
            model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = patch

            patched_logits = model.lm_head.output[:, -1, :].save()
    return patched_logits

patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_logprobs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 France (0.12)
 the (0.09)
 Italy (0.02)


<div style="background-color:#F2CFEE;padding:10px 10px;border-radius:20px">
<b>Want to know more?</b>

You may have suspected this, but there's nothing particularly special about a linear rotation! Maybe the model uses the magnitude of a vector, instead of its direction, to do meaningful computation. We can think about different intermediate transformations that might expose interesting units of computation.  Here are some key properties that we need these transformations to have:
<ul>
<li><b>invertible</b> - we need to be able to "undo" the transformation to return to the original representation space from the transformed space</li>
<li><b>separable</b> - we don't want concepts to interfere with each other during the transformation</li>
</ul>

To learn about more of their properties and their theoretical grounding, check out the <a href="https://arxiv.org/abs/2301.04709">causal abstraction theory paper</a>!
</div>

Hm, changing our unit of computation from neurons to linear subspaces didn't seem to help us out much... Patching the first few linear subspaces of our rotation matrix didn't successfully edit the model's representation of Paris's country. 

How do we automatically search for the linear subspaces we care about?

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Takeaway</b>

There are different potentially meaningful units of computations in a model's representation. Thinking about the model representation as one giant multi-dimensional vector, we can try to patch **linear subspaces** of the model's representation by first rotating it to a different space. 

How do we know which linear subspaces to patch? This is where DAS comes in!
</div>

## Enter DAS - automatically finding relevant linear subspaces

By rotating the hidden representations of our model, we can patch different linear subspaces. But how can we find the right linear subspace to patch?

Turns out, we can directly optimize our rotation vector to do this! Let's try to train our rotation matrix to maximize the likelihood of "Brazil" the country of Paris instead of "France".

In [16]:
# let's train our rotation matrix so that the patch output is Brazil instead of France
from tqdm import trange

optimizer = torch.optim.Adam(rotator.parameters())

loss_fn = torch.nn.CrossEntropyLoss()

counterfactual_answer = torch.tensor([model.tokenizer(" Brazil")["input_ids"][1]])

with trange(10) as progress_bar:
    for epoch in progress_bar:
        optimizer.zero_grad()
        
        # get patched logits using our rotation vector
        patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=True)

        # cross entropy loss - make last token be Brazil instead of France
        loss = loss_fn(patched_logits, counterfactual_answer)
        progress_bar.set_postfix({'loss': loss.item()})
        loss.backward()
        optimizer.step()

patched_logprobs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

100%|██████████| 10/10 [03:44<00:00, 22.44s/it, loss=1.94]

 Brazil (0.14)
 Portugal (0.08)
 the (0.06)





Looks like training our rotation matrix did the job! Now, patching from Rio to Paris changes Paris's country from France to Brazil.

In [18]:
base_prompt = "Paris is in the country of"

patched_logits = patch_linear_subspaces(rotator, base_prompt, source_hidden_states, with_grad=False)
patched_logprobs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 Brazil (0.64)
 Portugal (0.07)
, (0.06)


But did it interfere with other facts about Paris, such as its continent or language? Doesn't look like it!

In [20]:
new_base_prompt = "Paris is in the continent of"

patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_logprobs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 Europe (0.70)
 America (0.07)
, (0.04)


In [21]:
new_base_prompt = "Paris is a city whose main language is"

patched_logits = patch_linear_subspaces(rotator, new_base_prompt, source_hidden_states, with_grad=False)
patched_logprobs = torch.softmax(patched_logits, dim=-1)
top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 spoken (0.32)
 English (0.12)
, (0.08)


<div style="background-color:#F2CFEE;padding:10px 10px;border-radius:20px">
<b>Want to know more?</b>

If there are concepts that we know we want to keep the same, we can train DAS with a multi-task objective (i.e., "edit this property" + "keep this other property the same"). See the <a href="https://arxiv.org/abs/2402.17700">RAVEL</a> paper for more detail.
</div>

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Takeaway</b>

How can we patch certain concepts in a model's representation, such as the country of Paris, without messing with other concepts stored in the model, such as Paris's continent or language?

DAS to the rescue! By searching over sets of linear subspaces, DAS finds a linear subspace in the model that, when patched, edits the model's concept. The resulting patch is more precise - by patching individual linear subspaces, we have a better chance at making sure that only the specific concept we're looking for gets edited.  
</div>