# Token-Based Indexing

When indexing hidden states for specific tokens, use `.token[<idx>]` or `.t[<idx>]`.

As a preliminary example, let's get a hidden state from the model using `.t[<idx>]`.

In [1]:
from nnsight import LanguageModel

model = LanguageModel('gpt2', device_map='cuda')

In [2]:
with model.forward() as runner:
    with runner.invoke('The Eiffel Tower is in the city of') as invoker:
        # Save the last block's hidden state at the first token of the prompt.
        hidden_states = model.transformer.h[-1].output[0].t[0].save()

output = runner.output
hidden_states = hidden_states.value

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Let's see why token-based indexing is necessary.

In this example, we run invokes on two inputs with different tokenized lengths. We **incorrectly** index into the hidden states using normal Python indexing.

In [3]:
from rich import print

with model.forward() as runner:
    with runner.invoke('The') as invoker:
        incorrect_a = model.transformer.input[0][0][:,0].save()

    with runner.invoke('The Eiffel Tower is in the city of') as invoker:
        incorrect_b = model.transformer.input[0][0][:,0].save()

print(f"Shorter input: {incorrect_a.value}")
print(f"Longer input: {incorrect_b.value}")

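Notice that we indexed position 0 for both strings but received a different result from each invoke. **This is because, when there are multiple invocations, the inputs are batched and padding is performed on the left side, so position 0 of the shorter input is a padding token rather than its first real token. The `.token` / `.t` helpers account for this by indexing from the back.**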
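To make the effect of left padding concrete, here is a minimal pure-Python sketch of the idea. It is not nnsight's implementation: `PAD`, `left_pad`, `token_index`, and the token ids are all hypothetical names and values chosen for illustration.

```python
PAD = 0  # hypothetical padding id, chosen only for this illustration


def left_pad(batch, pad_id=PAD):
    """Left-pad every sequence to the length of the longest one."""
    width = max(len(seq) for seq in batch)
    return [[pad_id] * (width - len(seq)) + seq for seq in batch]


def token_index(padded_len, seq_len, idx):
    """Map a position in the unpadded sequence to a position in the padded row."""
    if not -seq_len <= idx < seq_len:
        raise IndexError("token index out of range")
    idx = idx % seq_len  # normalise negative indices
    return padded_len - seq_len + idx  # skip over the left padding


short_ids = [11]                      # made-up ids standing in for a one-token prompt
long_ids = [11, 22, 33, 44, 55, 66]   # made-up ids standing in for a longer prompt
lengths = [len(short_ids), len(long_ids)]
padded = left_pad([short_ids, long_ids])

# Naive indexing: position 0 of each padded row.
naive_first = [row[0] for row in padded]

# Padding-aware indexing: position 0 of each real sequence.
token_first = [row[token_index(len(row), n, 0)] for row, n in zip(padded, lengths)]

print(naive_first)  # [0, 11]  -> the shorter row yields the padding id
print(token_first)  # [11, 11] -> both rows yield their first real token
```

The naive lookup returns the padding id for the shorter row, while the offset-based lookup returns the first real token for both rows, which is the behaviour the token helpers are described as providing.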

Let's correctly index into the hidden states using token-based indexing.

In [4]:
with model.forward() as runner:
    with runner.invoke('The') as invoker:
        correct_a = model.transformer.input[0][0].t[0].save()
        
    with runner.invoke('The Eiffel Tower is in the city of') as invoker:
        correct_b = model.transformer.input[0][0].t[0].save()

print(f"Shorter input: {correct_a.value}")
print(f"Longer input: {correct_b.value}")

Now we have the correct tokens!