In [4]:
import torch
from taker import Model

m = Model("google/gemma-2b-it", dtype="int4", limit=1000)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are calling `save_pretrained` to a 4-bit converted model, but your `bitsandbytes` version doesn't support it. If you want to save 4-bit models, make sure to have `bitsandbytes>=0.41.3` installed.


Loaded model 'google/gemma-2b-it' with int4:
- Added 288 hooks across 18 layers


In [6]:
m.generate("Hello world in python:", 100)

('Hello world in python:',
 '\n\n```python\nprint("Hello world")\n```\n\n**Output:**\n\n```\nHello world\n```\n\n**Explanation:**\n\n1. `print()` is a built-in Python function that prints the given argument to the console.\n2. `"Hello world"` is the argument that we are passing to `print()`.\n3. The `print()` function will call the `__str__()` method of the string `"Hello world"` and print the output.\n\n**')

In [11]:
res = m.get_residual_stream("Hello world!")
print(res.shape)

res = m.get_residual_stream_decoder("Hello world!")
print(res.shape)

act = m.get_midlayer_activations("Hello world!")
print(act["mlp"].shape, act["attn"].shape)

torch.Size([1, 37, 4, 2048])
torch.Size([1, 19, 4, 2048])
torch.Size([1, 18, 4, 16384]) torch.Size([1, 18, 4, 8, 256])


In [15]:
m.hooks.disable_all_collect_hooks()

m.hooks.enable_collect_hooks(["attn_pre_out", "pre_decoder"], layers=[2,5])

m.get_outputs_embeds("This is some example text")

print( m.hooks["attn_pre_out"]["collect"] )
print( m.hooks["pre_decoder"]["collect"] )

[None, None, tensor([[[[ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01],
          [ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01],
          [ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01],
          ...,
          [ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01],
          [ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01],
          [ 9.4727e-02, -1.3330e-01, -6.3354e-02,  ...,  1.0071e-01,
            9.3506e-02, -1.0907e-01]],

         [[ 7.1838e-02, -1.4087e-01, -8.5022e-02,  ...,  9.3323e-02,
            9.8633e-02, -1.0114e-01],
          [-3.5864e-01, -2.8320e-01, -4.9292e-01,  ..., -4.5044e-02,
            1.9543e-01,  4.8096e-02],
          [-1.2002e+00, -5.6152e-01, -1.2900e+00,  ..., -3.1543e-01,
            3.8452e-01,  3.3984e-01],
       

In [16]:
print( m.hooks["attn_pre_out"]["collect"] )
print( m.hooks["pre_decoder"]["collect"] )


[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]


In [24]:
# We get the residual stream with no changes
m.hooks.neuron_replace["layer_0_mlp_pre_out"].reset()
res1 = m.get_residual_stream("Some text here")[:,:4,:3,:5]
print(res)


tensor([[[[ 1.0352e-01,  5.0354e-03, -3.2715e-02,  ..., -1.7334e-02,
           -8.9111e-03, -1.0254e-02],
          [ 3.0859e-01, -5.3223e-02,  1.6357e-02,  ...,  3.4180e-03,
            1.0010e-01, -2.4292e-02],
          [ 2.3340e-01, -9.7656e-02, -2.8687e-03,  ..., -3.3936e-02,
            6.4941e-02, -1.4258e-01],
          [ 1.9922e-01,  4.2236e-02, -1.4453e-01,  ...,  9.3384e-03,
            1.9531e-02, -2.8931e-02]],

         [[ 3.7383e+00,  4.4922e-02, -1.0879e+00,  ...,  3.7109e-01,
            2.3010e-02,  1.7102e-01],
          [ 3.9023e+00, -2.1836e+00,  1.1094e+00,  ...,  1.1309e+00,
            8.0078e-01, -4.5044e-01],
          [ 1.0137e+00, -1.2432e+00,  1.7749e-01,  ..., -1.2676e+00,
            3.9062e-02, -2.8203e+00],
          [ 1.5830e+00,  9.1016e-01, -1.8301e+00,  ...,  2.9297e-03,
           -1.0089e-01, -9.0137e-01]],

         [[ 3.4922e+00,  3.7048e-02, -9.6045e-01,  ...,  5.1514e-01,
            4.0466e-02,  2.9614e-01],
          [ 5.8594e-01, -1.0215e+

In [26]:
# We get the residual stream when we change:
# - layer 0 MLP, token 1, to all zeros
m.hooks.neuron_replace["layer_0_mlp_pre_out"].add_token(
    1, torch.zeros([m.cfg.d_mlp])
)
res2 = m.get_residual_stream("Some text here")[:,:4,:3,:5]
print((res2-res1)/res1)

Got 4, already saw 0, checking 2 out of max 2
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 2.0508,  0.0182,  6.1211, -2.1406, -0.3552],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000]],

         [[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 2.7734, -0.0179, -0.0384, -1.0674,  1.4033],
          [-1.6133, -0.0564, -0.1362, -0.1129, -0.1021]]]], device='cuda:0',
       dtype=torch.float16, grad_fn=<DivBackward0>)


In [38]:
offset_layer_0 = m.hooks.neuron_offsets["layer_0_mlp_pre_out"]
offset_layer_0.param.data = torch.zeros_like(offset_layer_0.param.data)
res1 = m.get_residual_stream(".")[..., :3, :, :5]
mlp1 = m.get_midlayer_activations(".")["mlp"][..., :5]
print(res1)

offset_layer_0.param.data = torch.ones_like(offset_layer_0.param.data) * 0.01
print(offset_layer_0.state_dict())
res2 = m.get_residual_stream(".")[..., :3, :, :5]
mlp2 = m.get_midlayer_activations(".")["mlp"][..., :5]

print((res1-res2)/res1)
print((mlp1-mlp2)/mlp1)


Got 2, already saw 0, checking 2 out of max 2
Got 2, already saw 0, checking 2 out of max 2
tensor([[[[ 4.6836e+00,  2.2791e-01, -1.4805e+00,  6.5247e-02, -7.5098e-01],
          [ 8.3516e+00, -2.8281e+00, -8.0859e+00,  1.6465e+00, -1.2539e+00]],

         [[ 3.5859e+00,  2.3853e-01, -7.8125e-01,  7.0801e-03, -2.4658e-01],
          [ 4.5234e+00, -1.7100e+00, -4.1680e+00,  1.5674e+00, -8.9844e-02]],

         [[ 3.7383e+00,  4.4922e-02, -1.0879e+00,  2.9956e-01, -1.2048e-01],
          [ 4.5234e+00, -1.7100e+00, -4.1680e+00,  1.5674e+00, -8.9844e-02]]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)
OrderedDict([('param', tensor([0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100], device='cuda:0',
       dtype=torch.float16))])
Got 2, already saw 0, checking 2 out of max 2
Got 2, already saw 0, checking 2 out of max 2
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000]],

         [[ 0.0000,  0.000

In [39]:
res1 = m.get_residual_stream(".")[..., :3, :, :5]
mlp1 = m.get_midlayer_activations(".")["mlp"][..., :5]
print(res1)

remove_indices = torch.zeros([m.cfg.n_layers, m.cfg.d_mlp], dtype=bool)
remove_indices[0, 1] = True
m.hooks.delete_mlp_neurons(remove_indices)

print(offset_layer_0.state_dict())
res2 = m.get_residual_stream(".")[..., :3, :, :5]
mlp2 = m.get_midlayer_activations(".")["mlp"][..., :5]

print((res1-res2)/res1)
print((mlp1-mlp2)/mlp1)

Got 2, already saw 0, checking 2 out of max 2
Got 2, already saw 0, checking 2 out of max 2
tensor([[[[ 4.6836e+00,  2.2791e-01, -1.4805e+00,  6.5247e-02, -7.5098e-01],
          [ 8.3516e+00, -2.8281e+00, -8.0859e+00,  1.6465e+00, -1.2539e+00]],

         [[ 3.5859e+00,  2.3853e-01, -7.8125e-01,  7.0801e-03, -2.4658e-01],
          [ 4.5234e+00, -1.7100e+00, -4.1680e+00,  1.5674e+00, -8.9844e-02]],

         [[ 3.7383e+00,  4.4922e-02, -1.0879e+00,  2.9932e-01, -1.2048e-01],
          [ 4.5352e+00, -1.7158e+00, -4.1445e+00,  1.5830e+00, -8.9905e-02]]]],
       device='cuda:0', dtype=torch.float16, grad_fn=<SliceBackward0>)
OrderedDict([('param', tensor([0.0100, 0.0100, 0.0100,  ..., 0.0100, 0.0100, 0.0100], device='cuda:0',
       dtype=torch.float16))])
Got 2, already saw 0, checking 2 out of max 2
Got 2, already saw 0, checking 2 out of max 2
tensor([[[[ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000]],

         [[ 0.0000,  0.000