In [2]:
import nnsight
import transformers
import torch
import tqdm

In [3]:
from collections import OrderedDict

input_size = 5
hidden_dims = 10
output_size = 2

net = torch.nn.Sequential(
    OrderedDict(
        [
            ("layer1", torch.nn.Linear(input_size, hidden_dims)),
            ("layer2", torch.nn.Linear(hidden_dims, output_size)),
        ]
    )
).requires_grad_(False)

In [4]:
from nnsight import NNsight

tiny_model = NNsight(net)

In [5]:
print(tiny_model)

Sequential(
  (layer1): Linear(in_features=5, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=2, bias=True)
)


In [6]:
print(net)

Sequential(
  (layer1): Linear(in_features=5, out_features=10, bias=True)
  (layer2): Linear(in_features=10, out_features=2, bias=True)
)


We enter the tracing context by calling `model.trace(<input>)` on an NNsight model

In [7]:
input = torch.rand((1, input_size))

with tiny_model.trace(input) as tracer:
    pass

In [8]:
with tiny_model.trace(input) as tracer:
    output = tiny_model.output.save()
print(output)

tensor([[0.1626, 0.4687]])


In [9]:
with tiny_model.trace(input) as tracer:
    l1_output = tiny_model.layer1.output.save()
print(l1_output)

tensor([[-0.2155,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])


In [12]:
with tiny_model.trace(input):
    l2_input = tiny_model.layer2.input.save()
print(l2_input)

tensor([[-0.2155,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])


In [14]:
with tiny_model.trace(input):
    l1_output = tiny_model.layer1.output
    l1_amax = torch.argmax(l1_output, dim=1).save()
print(l1_amax[0])

tensor(4)


In [15]:
with tiny_model.trace(input):
    value = (tiny_model.layer1.output.sum() + tiny_model.layer2.output.sum()).save()
print(value)

tensor(0.3216)


`nnsight.apply()` allows us to add new functions to the intervention graph.

In [16]:
# Take a tensor and return the sum of its elements
def tensor_sum(tensor):
    flat = tensor.flatten()
    total = 0
    for element in flat:
        total += element.item()

    return torch.tensor(total)

with tiny_model.trace(input) as tracer:
    custom_sum = nnsight.apply(tensor_sum, tiny_model.layer1.output).save()
    sum = tiny_model.layer1.output.sum()
    sum.save()

print(custom_sum, sum)

tensor(-0.3097) tensor(-0.3097)


In [22]:
from scipy.stats import entropy

with tiny_model.trace(input) as tracer:
    l1_output = tiny_model.layer1.output.save()
    l1_entropy = nnsight.apply(entropy, l1_output).save()
print(l1_output)
print(l1_entropy)

tensor([[-0.2155,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]


In [23]:
with tiny_model.trace(input):
    # Save the output before the edit to compare, applying .clone()
    # before saving as the setting operation is in place.
    l1_output_before = tiny_model.layer1.output.clone().save()

    # Access the -th index of the hidden state dimension and set it
    # to zero
    tiny_model.layer1.output[:, 0] = 0

    # Save the output after to see our edit
    l1_output_after = tiny_model.layer1.output.save()

print("Before:", l1_output_before)
print("After", l1_output_after)

Before: tensor([[-0.2155,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])
After tensor([[ 0.0000,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])


In [24]:
# Early Stopping
with tiny_model.trace(input):
    l1_out = tiny_model.layer1.output.save()
    tiny_model.layer1.output.stop()

# get the output of the first layer and stop tracing
print(l1_out)

tensor([[-0.2155,  0.3026, -0.2610, -0.2924,  0.8791, -0.3922,  0.6998, -0.2105,
         -0.4250, -0.3945]])


In [26]:
with tiny_model.trace(input) as tracer:
    rand_int = torch.randint(low=-10, high=10, size=(1,))
    with tracer.cond(rand_int % 2 == 0):
        tracer.log(rand_int, "is even")
    with tracer.cond(rand_int % 2 == 1):
        tracer.log(rand_int, "is odd")


tensor([-10]) is even


In [27]:
with tiny_model.trace(input) as tracer:
    non_rand_int = 8
    with tracer.cond(non_rand_int > 0):
        with tracer.cond(non_rand_int % 2 == 0):
            tracer.log(non_rand_int, "is positive and even")

8 is positive and even


In [30]:
with tiny_model.trace(input) as tracer:
    rand_int = torch.randint(low=-10, high=10, size=(1,))
    if rand_int % 2 == 0:
        tracer.log("Random Integer", rand_int, "is even")
    else:
        tracer.log("Random Integer", rand_int, "is odd")

Random Integer tensor([-10]) is even


In [31]:
with tiny_model.session() as session:

    li = nnsight.list()
    [li.append([num]) for num in range(0, 3)]
    li2 = nnsight.list().save()

    # You can create nested Iterator contexts
    with session.iter(li) as item:
        with session.iter(item) as item_2:
            li2.append(item_2)

print("\nList:", li2)


List: [0, 1, 2]


In [37]:
with tiny_model.session() as session:
    li = nnsight.list()
    [li.append([num]) for num in range(0, 3)]
    li2 = nnsight.list().save()

    for item in li:
        session.log(item)
        for item2 in item:
            session.log(item2)
            li2.append(item2)

print(li2)

[0]
0
[1]
1
[2]
2
[0, 1, 2]


In [38]:
from nnsight import LanguageModel

llm = LanguageModel("openai-community/gpt2", device_map="auto")

print(llm)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  (generator): Generator(
    (streamer): Streamer()
  )
)


In [55]:
import torch.nn.functional as F

def calc_entropy(probs, axis=1):
    return entropy(probs, axis=axis)

with llm.trace("ABCDEFGHIJKLMNOPQRSTUVWXY"):
    # Access the last layer using h[-1] as it's a ModuleList
    # Access the first index of .output as that's where the
    # hidden states are.
    llm.transformer.h[-1].mlp.output[0][:] = 0

    # Get the logits
    token_ids = llm.lm_head.output.save()

    # Convert logits to probabilities using softmax along vocab dimension
    probs = F.softmax(token_ids, dim=-1)

    # Calculate entropy on the probability distribution of next tokens
    pred_entropy = nnsight.apply(calc_entropy, probs.detach().cpu()).save()
    pred_token = token_ids.argmax(dim=-1).save()

print("Token IDs:", token_ids)
print("Next Token Entropy:", pred_entropy[0][-1])
print("Argmax:", pred_token)
print("Token:", llm.tokenizer.decode(pred_token[0][-1]))

Token IDs: tensor([[[ -27.5993,  -27.1712,  -29.4423,  ...,  -34.4424,  -34.2497,
           -27.0451],
         [ -72.0859,  -70.8134,  -75.3755,  ...,  -84.4147,  -81.6738,
           -72.0996],
         [ -85.3450,  -84.3682,  -90.1620,  ...,  -98.0732,  -96.9377,
           -86.8543],
         ...,
         [ -82.0589,  -77.7474,  -83.8439,  ...,  -94.5601,  -90.0955,
           -78.5493],
         [ -94.8259,  -91.0452,  -97.5826,  ..., -107.2448, -105.5906,
           -91.2598],
         [ -86.3132,  -81.4123,  -83.1329,  ...,  -98.7171,  -96.0636,
           -82.1496]]], device='mps:0', grad_fn=<LinearBackward0>)
Next Token Entropy: 1.2061377
Argmax: tensor([[   11,    13, 23852,    42, 31288,    45,  3185,    48,    49,  2257,
            52, 30133, 34278,    57]], device='mps:0')
Token: Z


When we call .trace() it's actually creating two different contexts behind the scenes. The first one is the tracing context, the second is the invoker context. The incoker context defines the values of the input and output Proxies.

If we call .trace() with some input, the input is passed on to the invoker. As there is only one input, only one invoker context is created.

If we call .trace() without an input, then we call tracer.invoke(input1) to manually create the invoker context with an input, input1. We can also repeatedly call tracer.invoke() to create the invoker context for additional inputs. Every subsequent time we call .invoke(), interventions within its context will only refer to the input in that particular invoke statement.

When exiting the tracing context, the inputs from all of the invokers will be batched together, and they will be executed in one forward pass. To test this out, let's do the same ablation experiment, but also add a control output for comparison.

In [60]:
with llm.trace() as tracer:
    with tracer.invoke("The Eiffel Tower is in the city of"):
        # Ablate the last MLP for only this batch
        llm.transformer.h[-1].mlp.output[0][:] = 0
        # Get the output for only the intervened on batch
        token_ids_intervention = llm.lm_head.output.argmax(dim=-1).save()
    with tracer.invoke("The Eiffel Tower is in the city of"):
        # Get the output for only the original batch
        token_ids_original = llm.lm_head.output.argmax(dim=-1).save()
print("Original token IDs:", token_ids_original)
print("Modified token IDs:", token_ids_intervention)
print("Original prediction:", llm.tokenizer.decode(token_ids_original[0][-1]))
print("Modified prediction:", llm.tokenizer.decode(token_ids_intervention[0][-1]))

Original token IDs: tensor([[ 198,   12,  417, 8765,  318,  257,  262, 3504, 7372, 6342]],
       device='mps:0')
Modified token IDs: tensor([[ 262,   12,  417, 8765,   11,  257,  262, 3504,  338, 3576]],
       device='mps:0')
Original prediction:  Paris
Modified prediction:  London


In [None]:
with llm.generate('The Eiffel Tower is in the city of', max_new_tokens=3) as tracer:
    hidden_states1 = llm.transformer.h[-1].output[0].save()
    # use module.next() to access the next intervention
    hidden_states2 = llm.transformer.h[-1].next().output[0].save()
    # saving the output allows you to save the hidden state across the initial prompt
    out = llm.generator.output.save()
print(hidden_s)