# Supervised Feature Selection: choosing the right unit of computation

By manipulating the components of a neural network in real-time, we can discover how these components represent abstract concepts. But **how do we know what unit of computation we should edit?** Should we change the entire residual stream? Only certain attention heads? Or the individual neurons themselves?

In this tutorial, we'll look at the problem of **supervised feature selection**: finding components in the model that localize specific concepts. We'll build up to Distributed Alignment Search (DAS), which automatically searches for a set of linear subspaces that represent a particular concept.

<div style="background-color:#FF9999;padding:10px 10px;border-radius:20px">
<b>Before we begin!</b>

It helps to be familiar with the following before we begin the tutorial:
<ul>
<li>Activation patching - check out the activation patching tutorial <a href="https://nnsight.net/notebooks/tutorials/activation_patching/">here</a>!</li>
</ul>
</div>

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Things we'll talk about</b>

In case you want to tell people what you learned today!
<ul>
<li><a href="https://arxiv.org/abs/2402.17700">RAVEL</a> - how to evaluate that we're only editing the concept we're looking for.</li>
<li><a href="https://arxiv.org/abs/2303.02536">DAS</a> - method for finding linear subspaces of model representations that store a particular concept.</li>
</ul>

Let's do this!
</div>

In [None]:
from IPython.display import clear_output

try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight
    !git clone https://github.com/AmirZur/nnsight-tutorials.git
    %cd nnsight-tutorials/

clear_output()

## Activation Patching Review - editing a model's knowledge of geography

We'll begin with a quick review of activation patching.

When we tell an LLM we want to travel to Paris, the LLM knows that Paris is in France. Where is that information stored? In this notebook, we'll try to find out **where a language model stores the country information for a particular city.**

In [1]:
# load model
import nnsight
from IPython.display import clear_output

model = nnsight.LanguageModel("meta-llama/Llama-3.2-1B", device_map="auto")
clear_output()

In [2]:
# does our model know where Paris is?
import torch

base_prompt = "Paris is in the country of"

with torch.no_grad():
    with model.trace(base_prompt) as tracer:
        base_tokens = tracer.invoker.inputs[0][0]['input_ids'][0]
        # Get logits from the lm_head
        base_logits = model.lm_head.output[:, -1, :].save()

base_logprobs = torch.softmax(base_logits, dim=-1)

top_completions = torch.topk(base_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 France (0.65)
 the (0.05)
 love (0.01)


Looks like our model knows that Paris is in France. Now let's make the model think that Paris is elsewhere! We can do this by **editing** the model's representation of Paris during its computation.

![two forward runs of a model, with an arrow between the residual stream activations of Rio and Paris. After the intervention is applied, the model outputs Brazil](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/patching_visualization.png?raw=true)

See the tutorial on <a href="https://nnsight.net/notebooks/tutorials/activation_patching/">activation patching</a> for more details!
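To make the mechanics concrete without loading a real model, here's a minimal sketch in plain Python. Everything in it is a toy assumption: `run_toy_model` and its "layers" (which just add a constant to every token vector) are hypothetical stand-ins for a transformer, not part of nnsight. It shows the three steps of a patch: cache a hidden state from the source run, overwrite the same position in the base run, and let the remaining layers continue from there.

```python
def run_toy_model(token_vectors, layer_shifts, patch=None):
    """Run the toy layers and return the hidden states after each layer.

    patch is an optional (layer_idx, token_idx, vector) triple: right after
    layer_idx runs, the hidden state at token_idx is overwritten with vector.
    """
    hidden = [list(vec) for vec in token_vectors]
    states_per_layer = []
    for layer_idx, shift in enumerate(layer_shifts):
        # toy "layer": add the same constant to every coordinate
        hidden = [[x + shift for x in vec] for vec in hidden]
        if patch is not None and patch[0] == layer_idx:
            _, token_idx, vector = patch
            hidden[token_idx] = list(vector)
        states_per_layer.append([list(vec) for vec in hidden])
    return states_per_layer


layer_shifts = [0.5, 0.25]                # two toy layers
base_input = [[1.0, 0.0], [0.0, 0.0]]     # made-up vectors standing in for the base prompt
source_input = [[0.0, 1.0], [0.0, 0.0]]   # made-up vectors standing in for the source prompt

# 1. run the source prompt and cache its hidden states
source_states = run_toy_model(source_input, layer_shifts)
base_states = run_toy_model(base_input, layer_shifts)

# 2. + 3. rerun the base prompt, overwriting layer 0 / token 0 with the cached source state
patched_states = run_toy_model(
    base_input, layer_shifts, patch=(0, 0, source_states[0][0])
)

print("base final state:   ", base_states[-1])
print("patched final state:", patched_states[-1])
```

In this toy the tokens never interact, so only the patched position changes. In a real transformer, attention lets later tokens read from the patched position, which is why a patch over one token can change the prediction at the last token.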

In [3]:
# collect representations for a city from a different country
source_prompt = "Rio is in the country of"
source_country = model.tokenizer(" Brazil")["input_ids"][1] # index 0 is the BOS token; index 1 is " Brazil" (note the leading space)

with torch.no_grad():
    with model.trace(source_prompt) as tracer:
        source_tokens = tracer.invoker.inputs[0][0]['input_ids'][0]
        # Get hidden states of all layers in the network.
        # We index the output at 0 because it's a tuple where the first index is the hidden state.
        source_hidden_states = [
            layer.output[0].save()
            for layer in model.model.layers
        ]

The next cell will take 1-2 minutes, since it runs one forward pass for every layer and token position in our prompt.

In [4]:
# activation patching intervention
patching_results = []

# iterate through all the layers
for layer_idx in range(len(model.model.layers)):
    patching_results_per_layer = []

    # iterate through all tokens
    for token_idx in range(len(base_tokens)):
        with model.trace(base_prompt) as tracer:
            # apply the patch from the source hidden states to the base hidden states
            model.model.layers[layer_idx].output[0][:, token_idx, :] = \
                source_hidden_states[layer_idx][:, token_idx, :]

            patched_logits = model.lm_head.output[:, -1, :]

            patched_logprobs = torch.softmax(patched_logits, dim=-1)

            patching_results_per_layer.append(patched_logprobs[0, source_country].item().save())

    patching_results.append(patching_results_per_layer)

In [5]:
import plotly.express as px
from nnsight.tracing.graph import Proxy

def plot_patching_results(
    patching_results,
    base_tokens,
    source_tokens
):
    # get values from proxy variables
    patching_results = nnsight.util.apply(patching_results, lambda x: x.value, Proxy)

    fig = px.imshow(
        patching_results,
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": "token", "y": "layer","color":"counterfactual probability"},
        x=[f"<span style=\"color: #156082\">{b}</span><br></br><span style=\"color: #E97132\">{s}</span>" for b, s in zip(base_tokens, source_tokens)]
    )

    return fig

In [6]:
decoded_base_tokens = [model.tokenizer.decode(token) for token in base_tokens]
decoded_source_tokens = [model.tokenizer.decode(token) for token in source_tokens]

plot_patching_results(patching_results, decoded_base_tokens, decoded_source_tokens)

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Takeaway</b>

Look at the plot above. What can you tell from it about where the model represents the country of Paris?

Notice how changing the model's activations over the Paris token makes a difference up until a certain layer. However, once we hit that layer, the information immediately gets transferred to the last token in the prompt.

It's likely then that the model "read" the country information of Paris around **Layer 8**. We'll focus the rest of our analysis on this layer.
</div>
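One rough way to read a heatmap like this programmatically is to ask, for each token position, what the deepest layer is at which patching still has a large effect. The sketch below assumes the results have already been converted to plain floats. The helper `last_layer_with_effect`, the threshold, and the toy grid are all made up for illustration; they are not the values from the plot.

```python
def last_layer_with_effect(patching_results, token_idx, threshold=0.5):
    """Return the deepest layer whose patch at token_idx still pushes the
    counterfactual probability to at least `threshold` (None if no layer does)."""
    effective = [
        layer_idx
        for layer_idx, row in enumerate(patching_results)
        if row[token_idx] >= threshold
    ]
    return max(effective) if effective else None


# made-up grid: rows are layers, columns are token positions
toy_results = [
    [0.0, 0.9, 0.0],
    [0.0, 0.8, 0.1],
    [0.0, 0.1, 0.7],
    [0.0, 0.0, 0.9],
]

for token_idx in range(3):
    print(token_idx, last_layer_with_effect(toy_results, token_idx))
```

In the toy grid, patching the middle token stops mattering after layer 1, which is the pattern we'd expect if the information moves to the last token right after that layer.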

## Our Challenge - disentangling concepts in a model's representations

Our patching experiments did the job! By swapping layer 8's activation for the Paris token with the activation for Rio, we successfully changed the country in question from France to Brazil.

But **how can we be sure that it's really the country that we edited?** It's actually much likelier that we changed the **entire city in question**. If we ask the model follow-up questions, we'll see that we changed not only the country that Paris is in, but also its continent, language, food, and other concepts that we didn't look for...

In [7]:
# by changing Paris's country, we also changed its continent!
TOKEN_INDEX = 1
LAYER_INDEX = 8

new_base_prompt = "Paris is in the continent of"

with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 South (0.55)
 America (0.11)
 North (0.10)


In [8]:
# as well as its language!
new_base_prompt = "Paris is a city whose main language is"

with model.trace(new_base_prompt) as tracer:
    # apply the same patch we did before
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, :] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 Portuguese (0.57)
 Spanish (0.12)
 English (0.10)


<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Takeaway</b>

We did too much! 

It's perhaps not surprising: when we intervene on the full residual stream over Paris, we edit more than just the country it's in. But the model must somehow retrieve Paris's country from its stored knowledge. Can we edit Paris's country without changing its continent or language?
</div>

<div style="background-color:#F2CFEE;padding:10px 10px;border-radius:20px">
<b>Want to know more?</b>

Our example comes directly from the <a href="https://arxiv.org/abs/2402.17700">RAVEL</a> dataset, which evaluates methods for disentangling language model representations. We won't explore the entire RAVEL dataset or methods evaluated in this tutorial, so we strongly encourage you to go check it out!
</div>
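To give a flavor of how such an evaluation can be scored, here's a rough sketch. It follows our understanding of the RAVEL setup: one score for whether the target attribute changed as intended, one for whether the other attributes stayed put, and their average. The function names and the toy predictions are our own assumptions, so treat the paper as the authoritative definition.

```python
def cause_score(patched_predictions, source_labels):
    """Fraction of target-attribute prompts where the patched model
    outputs the source entity's label (the edit took effect)."""
    hits = sum(p == s for p, s in zip(patched_predictions, source_labels))
    return hits / len(source_labels)


def iso_score(patched_predictions, base_labels):
    """Fraction of other-attribute prompts where the patched model
    still outputs the base entity's label (the edit left them alone)."""
    hits = sum(p == b for p, b in zip(patched_predictions, base_labels))
    return hits / len(base_labels)


def disentangle_score(cause, iso):
    """Average of the two scores; high only if the edit is both effective and isolated."""
    return (cause + iso) / 2


# made-up predictions in the spirit of the full residual-stream patch above
country_predictions = ["Brazil", "Brazil", "France", "Brazil"]  # target attribute
country_source_labels = ["Brazil"] * 4

other_predictions = ["South", "Portuguese"]  # continent and language after patching
other_base_labels = ["Europe", "French"]     # what Paris should still answer

cause = cause_score(country_predictions, country_source_labels)
iso = iso_score(other_predictions, other_base_labels)
print(cause, iso, disentangle_score(cause, iso))
```

A patch like ours would do well on the first score and poorly on the second: it changes the country, but drags the continent and language along with it.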

## Looking Ahead - what's the right unit to patch?

Our goal is to surgically edit the country of Paris without changing any of its other properties. What if by patching the entire residual stream activation we made too big of a cut?

What are we actually patching? If we look at the residual stream, we see a list of numbers, which we can think of as a set of neurons.

In [9]:
source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :]

tensor([[ 0.0111, -0.0206, -0.2613,  ..., -0.0281, -0.1300,  0.0346]])

In [10]:
source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, :].shape

torch.Size([1, 2048])

Can we intervene on just a handful of these neurons?

![patching the first 3 neurons of the activations of Rio and Paris](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/patching_neurons_visualization.png?raw=true)

In [11]:
# change the list of indices to try a set of neurons to patch!
NEURON_INDICES = [0, 1, 2, 4]

base_prompt = "Paris is in the country of"

with model.trace(base_prompt) as tracer:
    # Apply the patch from the source hidden states to the base hidden states
    model.model.layers[LAYER_INDEX].output[0][:, TOKEN_INDEX, NEURON_INDICES] = \
        source_hidden_states[LAYER_INDEX][:, TOKEN_INDEX, NEURON_INDICES]

    patched_logits = model.lm_head.output[:, -1, :]

    patched_logprobs = torch.softmax(patched_logits, dim=-1).save()

top_completions = torch.topk(patched_logprobs, 3, sorted=True)
for v, i in zip(top_completions.values[0], top_completions.indices[0]):
    print(f'{model.tokenizer.decode(i.item())} ({v.item():.2f})')

 France (0.64)
 the (0.05)
 love (0.01)


This didn't quite work - these four neurons aren't enough to edit Paris's country.

We can try to search over the neurons in the residual stream, but individual neurons are likely too small to make a difference, and there are far too many sets of neurons to search through...
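To get a feel for the size of that search space, we can count the candidate sets directly. This is just arithmetic on the hidden size we printed above (2048); `HIDDEN_SIZE` is a name we introduce here for illustration.

```python
import math

HIDDEN_SIZE = 2048  # width of the residual stream at one token position

# number of ways to choose exactly 4 neurons to patch
four_neuron_subsets = math.comb(HIDDEN_SIZE, 4)

# number of subsets of any size (each neuron is either patched or not)
all_subsets = 2 ** HIDDEN_SIZE

print(f"4-neuron subsets: {four_neuron_subsets:,}")
print(f"all subsets: a number with {len(str(all_subsets))} digits")
```

Even restricting ourselves to sets of exactly four neurons leaves hundreds of billions of candidates, and each one would need its own forward pass to evaluate.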

In the next section, we'll think about a different unit of computation - a **linear subspace** in the model - and go through a method that automatically searches for the linear subspace that represents the concept we're looking for. 
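As a preview of what patching a linear subspace could look like, here's a minimal sketch for the simplest case, a one-dimensional subspace (a single direction). We copy the source's component along that direction into the base vector and leave everything orthogonal to it untouched. The helper `patch_along_direction` and the vectors are made up for illustration, and the direction is chosen by hand here, whereas a search method would have to find it.

```python
import math


def dot(u, v):
    """Dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def patch_along_direction(base, source, direction):
    """Replace base's component along `direction` with source's component,
    keeping the part of base that is orthogonal to `direction`."""
    norm = math.sqrt(dot(direction, direction))
    unit = [d / norm for d in direction]
    delta = dot(source, unit) - dot(base, unit)
    return [b + delta * u for b, u in zip(base, unit)]


base = [1.0, 2.0, 3.0]     # made-up "Paris" activation
source = [4.0, 0.0, -1.0]  # made-up "Rio" activation

# patching a single neuron is the special case of an axis-aligned direction
print(patch_along_direction(base, source, [1.0, 0.0, 0.0]))

# a direction that mixes several neurons
print(patch_along_direction(base, source, [1.0, 1.0, 0.0]))
```

The first call behaves exactly like the neuron patch we tried earlier, while the second edits a combination of neurons at once. That extra freedom is what makes subspaces a more flexible unit than individual neurons.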

<div style="background-color:#F2CFEE;padding:10px 10px;border-radius:20px">
<b>Want to know more?</b>

Although we left it at "there are too many sets of neurons to search through", there actually is a method for automatically selecting a subset of neurons!
The method uses a differentiable binary mask (DBM) to find the right set of neurons to intervene on.

In the next section we'll go over DAS, whose search space encompasses DBM and which performs better on the RAVEL dataset.

While we won't cover DBM in this tutorial, we strongly encourage you to <a href="https://dcm.baulab.info/">check it out</a>!

</div>