# DAS Puzzle! Mystery toy model

In the last tutorial, we went over <a href="https://arxiv.org/abs/2303.02536">DAS</a> and used it to identify a linear subspace in an LLM that stored information about the country of Paris.

Now it's time to put your learning to the test, and investigate a mystery neural network!

In [None]:
from IPython.display import clear_output

try:
    import google.colab
    is_colab = True
except ImportError:
    is_colab = False

if is_colab:
    !pip install -U nnsight
    !git clone https://github.com/AmirZur/nnsight-tutorials.git
    %cd nnsight-tutorials/

clear_output()

## Exploring our mystery model

Our mystery model simulates the RAVEL example we went over in the previous tutorial. The model takes a two-word sentence as input:
* **city** (first word): "Paris", "Rome", or "Seattle"
* **property** (second word): "language", "food", "country", or "?"

and outputs the given city's value for that property (for example, "Paris language" → "French").
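The `in_features` of 3 and 4 in the model printout below match the sizes of the city and property vocabularies. This suggests that each word is fed in as a one-hot vector. The sketch below shows that kind of encoding. The vocabulary order and the `encode` helper are hypothetical, and `puzzle_utils` may tokenize differently.

```python
import torch
import torch.nn.functional as F

# Hypothetical vocabularies: the real token order inside puzzle_utils may differ.
CITIES = ["Paris", "Rome", "Seattle"]
PROPERTIES = ["language", "food", "country", "?"]

def encode(sentence: str) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: split a two-word sentence into one-hot city and property vectors."""
    city, prop = sentence.split()
    city_vec = F.one_hot(torch.tensor(CITIES.index(city)), num_classes=len(CITIES)).float()
    prop_vec = F.one_hot(torch.tensor(PROPERTIES.index(prop)), num_classes=len(PROPERTIES)).float()
    return city_vec, prop_vec

city_vec, prop_vec = encode("Rome food")
print(city_vec.tolist())  # [0.0, 1.0, 0.0]
print(prop_vec.tolist())  # [0.0, 1.0, 0.0, 0.0]
```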

In [1]:
import nnsight
import torch
from puzzle_utils import construct_mystery_model

mystery_model = nnsight.NNsight(construct_mystery_model())

In [3]:
# play around with the model!
print(mystery_model("Paris language"))
print(mystery_model("Rome food"))
print(mystery_model("Seattle country"))

French
pizza
USA


The model architecture is as follows:
* **city embedding** (`entity_embed`): linear embedding of the city token (first word).
* **property embedding** (`property_embed`): linear embedding of the property token (second word).
* **hidden layer** (`hidden_layer`): linear layer with a ReLU activation over the concatenated city and property embeddings.
* **output layer** (`out_head`): a single output neuron whose value is read as the ID of the output token, which is then decoded to a word (e.g., "French").

![Layer 1: 6-neuron vector over the token Paris, 6-neuron vector over the token Language. Layer 2: 6-neuron vector. Layer 3: single neuron, decoded to French.](https://github.com/AmirZur/nnsight-tutorials/blob/main/figures/mystery_model_architecture.png?raw=true)
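To make the data flow concrete, here is a minimal PyTorch sketch with the same layer names and shapes as the printout below. It is a hypothetical re-creation with random weights, not the code inside `construct_mystery_model`. It assumes one-hot inputs and stops at the raw output neuron, because we don't know exactly how the real model decodes that value into a word.

```python
from collections import OrderedDict

import torch
from torch import nn

class ToyMysteryModel(nn.Module):
    """Hypothetical stand-in with the same layer shapes as the mystery model; weights are random."""

    def __init__(self, n_cities: int = 3, n_properties: int = 4, hidden_dim: int = 6):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.entity_embed = nn.Linear(n_cities, hidden_dim, bias=False)
        self.property_embed = nn.Linear(n_properties, hidden_dim, bias=False)
        self.hidden_layer = nn.Sequential(OrderedDict([
            ("linear", nn.Linear(2 * hidden_dim, hidden_dim)),  # concatenated embeddings -> hidden
            ("activation", nn.ReLU()),
        ]))
        self.out_head = nn.Linear(hidden_dim, 1, bias=False)  # a single output neuron

    def forward(self, city_onehot: torch.Tensor, property_onehot: torch.Tensor) -> torch.Tensor:
        city = self.entity_embed(city_onehot)          # (6,)
        prop = self.property_embed(property_onehot)    # (6,)
        hidden = self.hidden_layer(torch.cat([city, prop], dim=-1))  # (12,) -> (6,)
        return self.out_head(hidden)                   # (1,)

model = ToyMysteryModel()
out = model(torch.tensor([1., 0., 0.]), torch.tensor([1., 0., 0., 0.]))
print(out.shape)  # torch.Size([1])
```

Calling `print(model)` on this sketch gives a printout laid out like the one below, which is a quick way to check that the shapes line up.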

In [2]:
mystery_model

MysteryModel(
  (entity_embed): Linear(in_features=3, out_features=6, bias=False)
  (property_embed): Linear(in_features=4, out_features=6, bias=False)
  (hidden_layer): Sequential(
    (linear): Linear(in_features=12, out_features=6, bias=True)
    (activation): ReLU()
  )
  (out_head): Linear(in_features=6, out_features=1, bias=False)
)
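You can cross-check the printout with some quick arithmetic. The 12 inputs of `hidden_layer.linear` are the two 6-dimensional embeddings concatenated, and the whole model has only a handful of parameters. The numbers below are copied from the printout. If nothing is hidden from it, `sum(p.numel() for p in construct_mystery_model().parameters())` should give the same total.

```python
# (in_features, out_features, has_bias) for each parameterized layer, copied from the printout above.
# ReLU has no parameters, so it is left out.
layers = {
    "entity_embed": (3, 6, False),
    "property_embed": (4, 6, False),
    "hidden_layer.linear": (12, 6, True),
    "out_head": (6, 1, False),
}

def n_params(in_features: int, out_features: int, has_bias: bool) -> int:
    # weight matrix plus an optional bias vector
    return in_features * out_features + (out_features if has_bias else 0)

# the hidden layer's 12 inputs are the two 6-dim embeddings concatenated
assert layers["hidden_layer.linear"][0] == layers["entity_embed"][1] + layers["property_embed"][1]

counts = {name: n_params(*shape) for name, shape in layers.items()}
print(counts)                # {'entity_embed': 18, 'property_embed': 24, 'hidden_layer.linear': 78, 'out_head': 6}
print(sum(counts.values()))  # 126
```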

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Note</b>

In this tutorial, we'll focus on the <b>city embedding</b>. Follow the code below to inspect the city embeddings for different city tokens. 
</div>

In [4]:
# use nnsight to read the model's representation of Rome
source_prompt = "Rome country"

with torch.no_grad():
    with mystery_model.trace(source_prompt):
        source_activations = mystery_model.entity_embed.output.save()
        source_output = mystery_model.output.save()

# print the model's prediction; the last line displays the saved entity_embed activations
print('Output:', source_output)
source_activations

Output: Italy


tensor([ 0., 12., 14., 18.,  0., 16.])
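nnsight's `.output.save()` handles the bookkeeping of reading an intermediate activation. For comparison, the sketch below captures the output of an `entity_embed` submodule with a plain PyTorch forward hook. It runs on a hypothetical stand-in model with random weights, and it is only meant to show roughly the same readout, not how nnsight works internally.

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    """Hypothetical stand-in with an `entity_embed` submodule like the mystery model's."""

    def __init__(self):
        super().__init__()
        self.entity_embed = nn.Linear(3, 6, bias=False)
        self.out_head = nn.Linear(6, 1, bias=False)

    def forward(self, city_onehot: torch.Tensor) -> torch.Tensor:
        return self.out_head(self.entity_embed(city_onehot))

model = TinyModel()
captured = {}

def save_output(module, inputs, output):
    # Runs after entity_embed's forward pass; keep a detached copy of its output.
    captured["entity_embed"] = output.detach().clone()

handle = model.entity_embed.register_forward_hook(save_output)
x = torch.tensor([0., 1., 0.])  # hypothetical one-hot for "Rome"
with torch.no_grad():
    model(x)
handle.remove()  # unregister so later calls are not recorded

print(captured["entity_embed"].shape)  # torch.Size([6])
```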

## Puzzle #1: Parsing the city embeddings

Each city is embedded as a 6-dimensional vector. **Each dimension corresponds to a property**, such as the city's language, food, or country. Your puzzle is to figure out how these properties are encoded.
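If each dimension stores one property, the "rotation" DAS has to find is just the standard basis. Using a one-hot vector e_i as the rotation vector then reads off exactly one property's coordinate. The sketch below applies this to Rome's embedding from the trace above. Which index holds which property is what the puzzle asks you to work out, so no mapping is assumed here.

```python
import torch

# Rome's city embedding, copied from the trace output above
rome = torch.tensor([0., 12., 14., 18., 0., 16.])

basis = torch.eye(6)   # row i is the one-hot vector e_i
coords = basis @ rome  # coords[i] = e_i . rome, i.e. the value stored in dimension i
for i, value in enumerate(coords.tolist()):
    print(f"dimension {i}: {value}")
```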

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Puzzle #1</b>

**Your mission, should you choose to accept it**: find out which dimension corresponds to which property. To get you started, we set up a single-dimension version of DAS as in the previous tutorial.
</div>

In [5]:
# NOTE: this is a faster variant of DAS.
# instead of learning a full (model dim x model dim) rotation matrix, we only learn the rows we patch:
# a matrix of shape (# patched dims x model dim), i.e. the weight of nn.Linear(model dim, # patched dims)
N_PATCHING_DIMS = 1

one_dim_rotator = torch.nn.Linear(mystery_model.hidden_dim, N_PATCHING_DIMS, bias=False)
one_dim_rotator.weight.data = torch.tensor([[1., 0., 0., 0., 0., 0.]])
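The weight of `nn.Linear(d, k)` has shape `(k, d)`, so each row is one patched direction. For the patch to act like a projection onto a subspace, those rows should be orthonormal, meaning W Wᵀ = I. The one-hot initialization above satisfies this trivially. The sketch below shows one way (an assumption, not the tutorial's method) to initialize a `k > 1` version with orthonormal rows, using a QR decomposition.

```python
import torch

d, k = 6, 2  # model dim and number of patched dims (k = 2 is just for illustration)

# nn.Linear(d, k) stores a (k, d) weight: one row per patched direction.
rotator = torch.nn.Linear(d, k, bias=False)
print(rotator.weight.shape)  # torch.Size([2, 6])

# Reduced QR of a random (d, k) matrix gives Q with orthonormal columns,
# so Q.T has orthonormal rows.
q, _ = torch.linalg.qr(torch.randn(d, k))
with torch.no_grad():
    rotator.weight.copy_(q.T)

W = rotator.weight.detach()
print(torch.allclose(W @ W.T, torch.eye(k), atol=1e-5))  # True

# The one-hot row used in the cell above is already orthonormal.
e0 = torch.tensor([[1., 0., 0., 0., 0., 0.]])
print(torch.allclose(e0 @ e0.T, torch.eye(1)))  # True
```

Note that gradient updates will not keep the rows orthonormal on their own. If you train W, you need to re-orthonormalize it or use a parametrization that enforces it.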

In [6]:
with mystery_model.trace("Paris language") as tracer:
    base = mystery_model.entity_embed.output.clone().save()

    one_dim_rotation_vector = one_dim_rotator.weight
    # get rotated dims from base
    rotated_base = torch.matmul(
        one_dim_rotation_vector, base
    )

    # get rotated dims from source
    source = source_activations
    rotated_source = torch.matmul(
        one_dim_rotation_vector, source
    )

    # remove the rotated base & replace it with the rotated source 
    rotated_patch = rotated_source - rotated_base

    # unrotate patched vector
    patch = base + torch.matmul(
        one_dim_rotation_vector.T, rotated_patch
    )

    # replace base with patch
    mystery_model.entity_embed.output = patch

    # token id of output (with gradient)
    patched_output = mystery_model.out_head.output.save()
    # decoded token of output (no gradient)
    patched_answer = mystery_model.output.save()

print("Output token ID:", patched_output.item())
print("Token ID for \"Italian\":", mystery_model.tokenizer.tokenize("Italian", type="output"))
print("Output token:", patched_answer)

Output token ID: 1.0
Token ID for "Italian": 5
Output token: French
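The patch in the cell above is `base + Wᵀ (W·source − W·base)`: it removes the base's component along W and puts the source's component in its place. The standalone sketch below pulls that arithmetic out onto plain tensors. The base vector is made up for illustration, and the source is Rome's embedding from earlier. With a one-hot W, exactly one coordinate is copied from source into base and the rest stay untouched.

```python
import torch

def patch_subspace(base: torch.Tensor, source: torch.Tensor, W: torch.Tensor) -> torch.Tensor:
    """Replace base's components along the rows of W with source's (assumes orthonormal rows)."""
    return base + W.T @ (W @ source - W @ base)

base = torch.tensor([1., 2., 3., 4., 5., 6.])        # hypothetical base activations
source = torch.tensor([0., 12., 14., 18., 0., 16.])  # Rome's embedding from the trace above

W = torch.tensor([[0., 0., 1., 0., 0., 0.]])  # one-hot: patch dimension 2 only
patched = patch_subspace(base, source, W)
print(patched.tolist())  # [1.0, 2.0, 14.0, 4.0, 5.0, 6.0]
```

At the other extreme, a full identity matrix (`torch.eye(6)`) as W swaps in the entire source vector.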


<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Puzzle #1</b>

Can you find the right dimension to change the patched output from "French" to "Italian"? How about for the other properties (country and food)?

Create a map from each concept ("language", "food", and "country") to its dimension, given either as a one-hot vector or as the index of the neuron that stores the concept.

<i>Hint: applying a softmax to the rotation vector can help reveal the single most important dimension!</i>
</div>
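To see what the hint's softmax does, here is a sketch of the parameterization only. It uses a vector of learnable logits (the values here are made up), and its softmax serves as the rotation vector. The result is non-negative and sums to 1, so training pushes the weight toward whichever dimensions help most. A lower temperature sharpens it toward a one-hot vector. The training loop itself is left to you, for example a loss between `patched_output` and the target token ID.

```python
import torch

# Hypothetical learnable logits, one per embedding dimension (d = 6).
logits = torch.tensor([0.1, 0.3, 2.5, 0.2, -1.0, 0.4], requires_grad=True)

weights = torch.softmax(logits, dim=-1)  # non-negative, sums to 1
print(weights.sum().item())              # approximately 1.0
print(weights.argmax().item())           # 2: the dimension with the largest logit

# Dividing by a small temperature pushes the softmax toward a one-hot vector.
sharp = torch.softmax(logits / 0.1, dim=-1)
print([round(w, 3) for w in sharp.tolist()])
```

A softmax vector has unit norm only when it is one-hot. Treat it as a search device for finding the dimension, then switch to the exact one-hot vector for the final patch.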

In [7]:
# your code here!
pass

## Puzzle #2: what's the missing property?

We haven't revealed everything to you quite yet! There's a mysterious property, **"?"**, that retrieves a city feature. However, it isn't working right now...

In [8]:
# something's wrong...
mystery_model("Seattle ?")

'France'

<div style="background-color:#C1E5F5;padding:10px 10px;border-radius:20px">
<b>Puzzle #2</b>

**Your mission, should you choose to accept it**: using your knowledge from the previous puzzle, inspect the **property embeddings** of the model. Can you **debug the "?" embedding**? What property was it supposed to encode?
</div>
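A useful fact for inspecting the property embeddings: in a bias-free `nn.Linear`, feeding in the one-hot vector e_i returns W e_i, which is column i of the weight matrix. So each column of `property_embed.weight` is one property token's embedding. The sketch below checks this on a random stand-in layer with the same shape. The index used for "?" is an assumption. On the unwrapped model from `construct_mystery_model()`, the real columns should be readable the same way.

```python
import torch

torch.manual_seed(0)
# Stand-in with the same shape as the mystery model's property_embed: Linear(4, 6, bias=False)
property_embed = torch.nn.Linear(4, 6, bias=False)

q_index = 3  # hypothetical position of "?" in the property vocabulary
one_hot = torch.zeros(4)
one_hot[q_index] = 1.0

with torch.no_grad():
    embedding = property_embed(one_hot)  # W @ e_i

# The embedding of token i is exactly column i of the weight matrix.
print(torch.allclose(embedding, property_embed.weight[:, q_index]))  # True
```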

In [9]:
# your code here!
pass