In [1]:
# some notes from SAE tutorial on gemmascope https://colab.research.google.com/drive/17dQFYUYnuKnP6OwQPH9v_GSYUW5aj-Rp#scrollTo=12wF3f7o1Ni7
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

# this is what you do if you only want inference (not training)
# saves on memory usage
torch.set_grad_enabled(False) 

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x14afb8170>

In [2]:
# every transformer has a tokenizer, so we load the one for the model we want to use
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b")
# this downloads the model (or loads it from disk cache)
model = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b")

Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00, 16.86it/s]


In [3]:
# we want to study what's happening in the model when we run some input text through it
input_text = "hello, Yoda my name is"
# the first step is to tokenize the input text
input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)

In [4]:
input_ids['input_ids']

tensor([[     2,  17534, 235269, 146433,    970,   1503,    603]])

In [117]:
def print_input(input_ids):
  """Print every token of the first sequence next to its position and decoded text.

  Args:
  - input_ids: mapping with an 'input_ids' entry; only its first row is printed.

  Decoding uses the notebook-level `tokenizer`.
  """
  first_sequence = input_ids['input_ids'][0]
  for position, token in enumerate(first_sequence):
    decoded = tokenizer.decode(token)
    print(f"{position} Token {token}: \t\t{decoded}")

print_input(input_ids)

0 Token 2: 		<bos>
1 Token 17534: 		hello
2 Token 235269: 		,
3 Token 146433: 		 Yoda
4 Token 970: 		 my
5 Token 1503: 		 name
6 Token 603: 		 is


In [6]:
# now we can run our model. Let's generate just 2 more tokens for now
outputs = model.generate(**input_ids, max_new_tokens=2)
print("output tokens", len(outputs[0]))
outputs[0]


The 'max_batch_size' argument of HybridCache is deprecated and will be removed in v4.46. Use the more precisely named 'batch_size' argument instead.
Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)


output tokens 9


tensor([     2,  17534, 235269, 146433,    970,   1503,    603, 146433,    578])

In [7]:
# we need to turn the output tokens back into text (the output includes the input)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
generated_text

'hello, Yoda my name is Yoda and'

In [8]:
# To get hidden states for the input, we run a plain forward pass on the input tokens (not the generated sequence)
hidden_output = model(**input_ids, output_hidden_states=True)

# Access hidden states
hidden_states = hidden_output.hidden_states


In [9]:

print("input tokens", len(input_ids['input_ids'][0]))
print("hidden states (layers)", len(hidden_states))
print("hidden state shape", hidden_states[20].shape)

input tokens 7
hidden states (layers) 27
hidden state shape torch.Size([1, 7, 2304])


In [10]:
print(model)

Gemma2ForCausalLM(
  (model): Gemma2Model(
    (embed_tokens): Embedding(256000, 2304, padding_idx=0)
    (layers): ModuleList(
      (0-25): 26 x Gemma2DecoderLayer(
        (self_attn): Gemma2Attention(
          (q_proj): Linear(in_features=2304, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (v_proj): Linear(in_features=2304, out_features=1024, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2304, bias=False)
          (rotary_emb): Gemma2RotaryEmbedding()
        )
        (mlp): Gemma2MLP(
          (gate_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (up_proj): Linear(in_features=2304, out_features=9216, bias=False)
          (down_proj): Linear(in_features=9216, out_features=2304, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): Gemma2RMSNorm((2304,), eps=1e-06)
        (pre_feedforward_layernorm): Gemma2RMSNorm((2304,), eps

In [11]:
print(f"Number of layers in config: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")

Number of layers in config: 26
Hidden size: 2304


In [12]:
def gather_residual_activations(model, target_layer, inputs):
  """
  Run `inputs` through `model` once and return the output of a single layer.

  A forward hook is registered on model.model.layers[target_layer] for the
  duration of one forward pass and is always removed afterwards, even if the
  forward pass raises.

  Args:
  - model: The model from which we want to gather activations. It must expose
    its layers as model.model.layers.
  - target_layer: The index into model.model.layers of the layer to capture.
  - inputs: The input passed positionally to the model's forward pass.

  Returns:
  - target_act: The captured layer output. If the layer returns a tuple, its
    first element is returned; if it returns a bare tensor, that tensor is
    returned unchanged. None if the hooked layer never ran.
  """
  target_act = None

  def gather_target_act_hook(mod, hook_inputs, outputs):
    nonlocal target_act # make sure we can modify the target_act from the outer scope
    # Only unpack when the layer really returned a tuple; indexing a bare
    # tensor with [0] would silently drop the batch dimension instead.
    target_act = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
    return outputs

  # we could also easily target the MLP output by registering the same hook on
  # model.model.layers[target_layer].mlp instead
  handle = model.model.layers[target_layer].register_forward_hook(gather_target_act_hook)
  try:
    # Call the module itself (rather than .forward) so hooks are dispatched normally.
    _ = model(inputs)
  finally:
    # Never leave the hook attached: a leaked hook would keep firing on every later call.
    handle.remove()
  return target_act

In [13]:
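# layers[20] is the 21st decoder layer (0-indexed); its output corresponds to hidden_states[21], since hidden_states has one more entry (27) than there are layers (26)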
target_act = gather_residual_activations(model, 20, input_ids['input_ids'])

In [14]:
target_act.shape

torch.Size([1, 7, 2304])

In [15]:
target_act

tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-7.1130,  2.3935, -0.3611,  ..., -0.0207,  4.9698,  2.0012],
         [-8.1484,  0.2589, -0.5944,  ..., -1.2177,  2.7641,  2.0122],
         ...,
         [ 1.5289, -3.7196,  7.7939,  ..., -9.8385,  0.1159,  0.6954],
         [-3.1761,  1.1008,  0.5456,  ..., -1.7524, -2.0917,  2.4893],
         [ 0.9512, -2.4978, -0.4893,  ..., -4.7208, -5.6637, -0.7142]]])

In [16]:
hidden_states[21]

tensor([[[ 1.9440,  1.7632, -2.0879,  ...,  1.6978, -2.0868, -0.0178],
         [-7.1130,  2.3935, -0.3611,  ..., -0.0207,  4.9698,  2.0012],
         [-8.1484,  0.2589, -0.5944,  ..., -1.2177,  2.7641,  2.0122],
         ...,
         [ 1.5289, -3.7196,  7.7939,  ..., -9.8385,  0.1159,  0.6954],
         [-3.1761,  1.1008,  0.5456,  ..., -1.7524, -2.0917,  2.4893],
         [ 0.9512, -2.4978, -0.4893,  ..., -4.7208, -5.6637, -0.7142]]])

## SAE Time

In [17]:
from huggingface_hub import hf_hub_download

In [18]:
# We download the weights for the SAE we want to use
# https://www.neuronpedia.org/gemma-2-2b/20-gemmascope-res-16k
path_to_params = hf_hub_download(
    repo_id="google/gemma-scope-2b-pt-res",
    filename="layer_20/width_16k/average_l0_71/params.npz",
    force_download=False,
)
# np.load on an .npz returns a lazy archive object that keeps the file open.
# Read every array eagerly inside a `with` block so the file is closed again;
# `params` stays a name -> numpy array mapping for the cells below.
with np.load(path_to_params) as npz_file:
    params = {name: npz_file[name] for name in npz_file.files}
# torch.from_numpy shares memory with the numpy arrays held in `params`.
pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

In [19]:
import torch.nn as nn
class JumpReLUSAE(nn.Module):
  """Sparse autoencoder with a JumpReLU activation.

  A feature is active only where its pre-activation exceeds that feature's
  entry in `threshold`; everything else is set to zero.
  """

  def __init__(self, d_model, d_sae):
    super().__init__()
    # Every parameter starts as zeros: the real values are expected to be
    # loaded from pre-trained weights (e.g. via load_state_dict), not trained here.
    self.W_enc = nn.Parameter(torch.zeros((d_model, d_sae)))
    self.W_dec = nn.Parameter(torch.zeros((d_sae, d_model)))
    self.threshold = nn.Parameter(torch.zeros((d_sae,)))
    self.b_enc = nn.Parameter(torch.zeros((d_sae,)))
    self.b_dec = nn.Parameter(torch.zeros((d_model,)))

  def encode(self, input_acts):
    """Map model activations (last dim d_model) to SAE feature activations (last dim d_sae)."""
    pre_acts = torch.matmul(input_acts, self.W_enc) + self.b_enc
    # The "jump": zero out every feature whose pre-activation is not above its threshold.
    above_threshold = pre_acts > self.threshold
    return above_threshold * torch.relu(pre_acts)

  def decode(self, acts):
    """Map SAE feature activations back to model space; adds b_dec once."""
    return torch.matmul(acts, self.W_dec) + self.b_dec

  def forward(self, acts):
    """Encode then decode, returning the reconstruction of `acts`."""
    return self.decode(self.encode(acts))

In [20]:
params['W_dec'].shape[0]

16384

In [23]:
sae = JumpReLUSAE(params['W_enc'].shape[0], params['W_enc'].shape[1])
sae.load_state_dict(pt_params)

<All keys matched successfully>

In [24]:
sae_acts = sae.encode(target_act.to(torch.float32))
reconstruction = sae.decode(sae_acts)

In [25]:
sae_acts.shape

torch.Size([1, 7, 16384])

In [26]:
reconstruction.shape

torch.Size([1, 7, 2304])

In [27]:
values, inds = sae_acts.max(-1)

inds

tensor([[ 6631, 14956, 10299, 15449, 11302,  8564, 15449]])

In [28]:
values

tensor([[2028.7983,  122.3900,  107.6448,   90.2571,   89.6321,  102.8196,
           42.7972]])

In [29]:
k = 5  # Change this to the number of top values you want
values, indices = torch.topk(sae_acts, k, dim=-1)

print(f"Top {k} values shape: {values.shape}")
print(f"Top {k} indices shape: {indices.shape}")

# Print the top k values and indices for the first sequence item
print(f"\nTop {k} values for first sequence item:")
print(values[0, 0])
print(f"\nTop {k} indices for first sequence item:")
print(indices[0, 0])

Top 5 values shape: torch.Size([1, 7, 5])
Top 5 indices shape: torch.Size([1, 7, 5])

Top 5 values for first sequence item:
tensor([2028.7983,  781.3959,  534.8594,  264.1917,  252.5279])

Top 5 indices for first sequence item:
tensor([ 6631,   743,  5052, 16057,  9479])


In [30]:
from IPython.display import IFrame
# URL pattern for an embeddable Neuronpedia feature dashboard.
# The three placeholders are filled positionally by get_dashboard_html.
html_template = (
    "https://neuronpedia.org/{}/{}/{}"
    "?embed=true&embedexplanation=true&embedplots=true&embedtest=true&height=300"
)

def get_dashboard_html(sae_release = "gemma-2-2b", sae_id="20-gemmascope-res-16k", feature_idx=0):
    """Return the dashboard URL for one SAE feature.

    Despite the name, the result is a URL string (suitable for IFrame), not HTML.
    The arguments fill the three path placeholders of `html_template` in order.
    """
    path_parts = (sae_release, sae_id, feature_idx)
    return html_template.format(*path_parts)



In [97]:
sae_release = "gemma-2-2b"
layer = 20
sae_id = f"{layer}-gemmascope-res-16k"
feature_idx = 15449

In [98]:
html = get_dashboard_html(sae_release = sae_release, sae_id=sae_id, feature_idx=feature_idx)
IFrame(html, width=1200, height=600)

In [247]:
def get_steering_vector(sae, feature_index, scale=500, normalize=True):
  sae_acts = torch.zeros(1, sae.W_dec.shape[0], device=sae.W_dec.device)
  sae_acts[0, feature_index] = scale
  steering = sae.decode(sae_acts)
  if normalize:
    steering = steering / torch.norm(steering, dim=-1, keepdim=True)
  return steering

In [68]:

def modify_layer_activation(model, target_layer, input_ids, steering, scale, max_new_tokens):
    """
    Generate tokens while adding a scaled steering vector to one layer's output.

    A forward hook on model.model.layers[target_layer] adds `steering * scale`
    to that layer's output on every forward pass made during generation. The
    hook is always removed afterwards, even if generation raises.

    Args:
    - model: The LLM model; must expose model.model.layers and generate().
    - target_layer: The index of the layer whose output is modified
    - input_ids: The input token IDs passed to model.generate
    - steering: The vector added to the layer output; it must broadcast against
      that output (this notebook uses shape (1, d_model))
    - scale: Multiplier applied to `steering` before it is added
    - max_new_tokens: Number of tokens to generate

    Returns:
    - modified_output: Whatever model.generate returns with the hook active
    """
    def capture_and_modify_hook(module, inputs, outputs):
        # A layer may return a tuple (activation first) or a bare tensor.
        returns_tuple = isinstance(outputs, tuple)
        original_act = (outputs[0] if returns_tuple else outputs).detach()
        # Match the activation's dtype and device so the addition neither
        # promotes the activation's dtype nor fails across devices.
        delta = torch.as_tensor(steering, dtype=original_act.dtype, device=original_act.device) * scale
        modified_act = original_act + delta
        # Hand back the same structure the layer produced, with the activation replaced.
        if returns_tuple:
            return (modified_act,) + outputs[1:]
        return modified_act

    # Register the hook
    handle = model.model.layers[target_layer].register_forward_hook(capture_and_modify_hook)

    try:
        # Run the model with the modified activation
        with torch.no_grad():
            modified_output = model.generate(input_ids, max_new_tokens=max_new_tokens)
    finally:
        # Remove the hook so later, unsteered calls to the model are unaffected
        handle.remove()

    return modified_output


In [69]:
# we want to study what's happening in the model when we run some input text through it
input_text2 = "Hello, my name is"
# the first step is to tokenize the input text
input_ids2 = tokenizer(input_text2, return_tensors="pt", add_special_tokens=True)
max_new_tokens = 20

In [70]:
outputs2 = model.generate(**input_ids2, max_new_tokens=max_new_tokens)
generated_text2 = tokenizer.decode(outputs2[0], skip_special_tokens=True)

In [71]:
target_layer = 20  # The layer you want to modify
feature_index = 15449  # The feature index you want to modify
modification_value = 500.0  # The value to add to the feature's activation

In [72]:
steering = get_steering_vector(sae, feature_index, modification_value)
steering.shape

torch.Size([1, 2304])

In [73]:
steering

tensor([[-0.0153,  0.0034, -0.0012,  ..., -0.0239, -0.0238, -0.0027]])

In [75]:
steering2 = get_steering_vector(sae, feature_index, 1)

In [76]:
steering2

tensor([[ 6.0734e-03, -1.6376e-03,  6.5223e-03,  ..., -6.3663e-03,
          1.2676e-02,  7.8327e-05]])

In [77]:
torch.norm(steering2).item()

1.0000001192092896

In [78]:
w = sae.W_dec[feature_index] * 500
w

tensor([ -9.4724,   2.1418,  -2.0042,  ..., -11.5200, -15.4272,  -1.4665])

In [79]:
b = sae.b_dec
b

Parameter containing:
tensor([ 1.2695, -0.3415,  1.3470,  ..., -1.2878,  2.6408,  0.0191],
       requires_grad=True)

In [80]:
(w + b) / torch.norm(w + b).item()

tensor([-0.0153,  0.0034, -0.0012,  ..., -0.0239, -0.0238, -0.0027])

In [81]:
# Calculate cosine similarity between steering and steering2
cosine_similarity = torch.nn.functional.cosine_similarity(steering, steering2, dim=1)

print(f"Cosine similarity between steering and steering2: {cosine_similarity.item():.6f}")


Cosine similarity between steering and steering2: 0.369560


In [90]:
strength = 400
modified_output = modify_layer_activation(model, target_layer, input_ids2['input_ids'], steering, strength, max_new_tokens)

# Generate text from the modified output
modified_text = tokenizer.decode(modified_output[0], skip_special_tokens=True)

print("Input text:", input_text2)
print("Original text:", generated_text2)
print("Modified text:", modified_text)

Input text: Hello, my name is
Original text: Hello, my name is Dr. David and I'm a professor of mathematics at the University of California, Davis. I
Modified text: Hello, my name is Luke and I'm a Jedi. I' Star Wars, I's a Star Wars fan


In [265]:
def get_breakdown(input_text, layer, max_new_tokens=1):
  """
  Tokenize a prompt, generate a continuation, and collect SAE features for it.

  Uses the notebook-level `tokenizer`, `model` and `sae`, and prints the input
  and generated text.

  Args:
  - input_text: The prompt to analyse.
  - layer: Index of the layer whose output is captured and fed to the SAE.
  - max_new_tokens: Number of tokens to generate for the printed continuation.

  Returns a 6-tuple, in this order (callers index it positionally):
  - [0] input_text
  - [1] input_ids: the tokenizer output for input_text
  - [2] outputs: the token ids returned by model.generate
  - [3] generated_text: `outputs[0]` decoded without special tokens
  - [4] target_act: output of the chosen layer for the prompt tokens
  - [5] sae_acts: SAE feature activations computed from target_act
  """
  # the first step is to tokenize the input text
  input_ids = tokenizer(input_text, return_tensors="pt", add_special_tokens=True)
  outputs = model.generate(**input_ids, max_new_tokens=max_new_tokens)
  generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
  print("Input text:", input_text)
  print("Generated text:", generated_text)

  target_act = gather_residual_activations(model, layer, input_ids['input_ids'])
  # Cast to float32 before encoding, as the earlier sae.encode cell does, so
  # the activations match the float32 SAE weights whatever dtype the model uses.
  sae_acts = sae.encode(target_act.to(torch.float32))

  return input_text, input_ids, outputs, generated_text, target_act, sae_acts



In [266]:
dog = get_breakdown("I took my dog to the vet", layer)

Input text: I took my dog to the vet
Generated text: I took my dog to the vet for


In [267]:
dog[0] # input text
dog[1] # input ids
dog[2] # outputs
dog[3] # generated text
dog[4] # target act
dog[5] # sae acts

tensor([[[ 0.0000,  0.0000, 22.9148,  ...,  0.0000,  0.0000, 23.6990],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]]])

In [268]:
print_input(dog[1])

0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3895: 		 took
3 Token 970: 		 my
4 Token 5929: 		 dog
5 Token 577: 		 to
6 Token 573: 		 the
7 Token 13512: 		 vet


In [269]:
dog_token_idx = 4

In [270]:
dog[5][0][dog_token_idx]  


tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [271]:
torch.topk(dog[5][0][dog_token_idx], k=32)

torch.return_types.topk(
values=tensor([87.1561, 73.1058, 62.1382, 55.6819, 42.4143, 41.5249, 32.7077, 31.0943,
        29.6321, 28.7857, 25.7395, 25.6164, 25.5423, 24.0560, 23.7811, 23.5026,
        22.5944, 22.0243, 21.5107, 20.5071, 20.4034, 19.6517, 19.4906, 19.4306,
        19.2856, 19.2031, 18.7600, 18.6525, 18.0070, 17.9126, 17.5653, 17.4477]),
indices=tensor([12082,  3717, 14838,  6631,  3940, 11640,  6299,  4949,  1692,  1596,
        12468,  9768,  1089,  3645,  8476,  3461,  4418,  9251,  2860,   743,
         8837,  9988,  4956, 15611,  4310,    49,  6381,  5052,  3567, 13130,
        12001, 11700]))

In [170]:
dog_features = torch.topk(dog[5][0][dog_token_idx], k=32).indices.tolist()

In [163]:
html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=dog_features[1])
IFrame(html, width=1200, height=600)

In [178]:
cat = get_breakdown("I took my cat to the vet", layer)
# `cat` is the tuple returned by get_breakdown:
# (input_text, input_ids, outputs, generated_text, target_act, sae_acts)
print_input(cat[1])
cat_token_idx = 4  # position of the " cat" token in the listing printed by print_input
# 32 strongest SAE features at the " cat" token
torch.topk(cat[5][0][cat_token_idx], k=32)


Input text: I took my cat to the vet
Generated text: I took my cat to the vet for
0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3895: 		 took
3 Token 970: 		 my
4 Token 4401: 		 cat
5 Token 577: 		 to
6 Token 573: 		 the
7 Token 13512: 		 vet


torch.return_types.topk(
values=tensor([77.9781, 67.4630, 50.3103, 47.5962, 45.2936, 42.3410, 41.1088, 35.0701,
        33.2441, 32.8526, 32.3556, 29.8324, 26.8098, 26.6279, 25.8563, 23.5551,
        23.0960, 23.0918, 23.0880, 22.3735, 22.1612, 21.7123, 21.5312, 20.6601,
        20.0952, 19.9142, 19.8482, 19.8014, 19.0374, 18.9671, 18.4707, 18.3619]),
indices=tensor([ 3717,  6631,  6772,  3705, 11640,  3940, 12082,  1596,  1692,  6299,
         8837,   743,    49,  4949,  9251,  4418,  3645,  9237,  9768,  8476,
         9388, 12001, 11700, 12085, 13130,  2468,  2860, 11923,  4956, 14795,
         5883,  9921]))

In [195]:
# 0 - caregiving / pet relationship
# 2, 3 - cat
cat_features = torch.topk(cat[5][0][cat_token_idx], k=32).indices.tolist()
html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=cat_features[2])
IFrame(html, width=1200, height=600)


In [180]:
pet = get_breakdown("I took my pet to the vet", layer)
# `pet` is the tuple returned by get_breakdown:
# (input_text, input_ids, outputs, generated_text, target_act, sae_acts)
print_input(pet[1])
pet_token_idx = 4  # position of the " pet" token in the listing printed by print_input
# 32 strongest SAE features at the " pet" token
torch.topk(pet[5][0][pet_token_idx], k=32)



Input text: I took my pet to the vet
Generated text: I took my pet to the vet for
0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3895: 		 took
3 Token 970: 		 my
4 Token 7327: 		 pet
5 Token 577: 		 to
6 Token 573: 		 the
7 Token 13512: 		 vet


torch.return_types.topk(
values=tensor([104.9154,  60.8711,  50.7887,  45.2049,  43.0934,  42.2413,  28.3095,
         25.3525,  24.8617,  24.5862,  24.2275,  23.3301,  21.9932,  21.6487,
         20.9348,  19.8659,  19.3494,  18.9593,  18.5117,  18.3310,  17.9691,
         17.5346,  17.4015,  17.1198,  16.9894,  16.8812,  16.6862,  16.6782,
         16.5374,  16.1531,  14.9676,  14.5475]),
indices=tensor([14465,  1089,  3717, 15175,  6631, 12082,  9768,  9988,  1692, 10601,
         3940,  3567, 11640,  6299,  3645, 15969,  7938,  1596,  8476,    49,
        15572,  2012,  3404,   410,  4949,  5426,  2048, 13757,  3063,  5883,
         6772,  6381]))

In [184]:
pet_features = torch.topk(pet[5][0][pet_token_idx], k=32).indices.tolist()
html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=pet_features[3])
IFrame(html, width=1200, height=600)

In [185]:
tiger = get_breakdown("I took my tiger to the vet", layer)
# `tiger` is the tuple returned by get_breakdown:
# (input_text, input_ids, outputs, generated_text, target_act, sae_acts)
print_input(tiger[1])
tiger_token_idx = 4  # position of the " tiger" token in the listing printed by print_input
# 32 strongest SAE features at the " tiger" token
torch.topk(tiger[5][0][tiger_token_idx], k=32)



Input text: I took my tiger to the vet
Generated text: I took my tiger to the vet today
0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 3895: 		 took
3 Token 970: 		 my
4 Token 31469: 		 tiger
5 Token 577: 		 to
6 Token 573: 		 the
7 Token 13512: 		 vet


torch.return_types.topk(
values=tensor([75.6183, 70.5476, 55.1671, 32.2650, 31.3126, 31.0271, 30.8081, 29.7397,
        29.6808, 29.3874, 27.7411, 25.9877, 25.0921, 24.3694, 23.8438, 22.7950,
        22.7310, 21.8632, 21.3663, 19.6248, 18.9306, 18.7682, 18.4276, 18.4126,
        17.6487, 17.6335, 17.4958, 17.3650, 16.8088, 16.5935, 16.0777, 16.0739]),
indices=tensor([ 6631,  8837, 11674,  3404,  6772,  4502, 11203,  7600,  2312, 15175,
         1692,  9768,  3645,   743,  4591,  4503, 12084, 13353, 11038,  3717,
         6161,  3019,  8476, 12284, 14804,  5400,  1416, 10805,  3567,  9193,
         5299, 12599]))

In [191]:
tiger_features = torch.topk(tiger[5][0][tiger_token_idx], k=32).indices.tolist()
html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=tiger_features[4])
IFrame(html, width=1200, height=600)


In [264]:
# get_breakdown("I have a ", layer)

In [272]:
ihavecat = get_breakdown("I have a cat", layer)
ihavecat_token_idx = 4
print_input(ihavecat[1])
torch.topk(ihavecat[5][0][ihavecat_token_idx], k=32)

Input text: I have a cat
Generated text: I have a cat that
0 Token 2: 		<bos>
1 Token 235285: 		I
2 Token 791: 		 have
3 Token 476: 		 a
4 Token 4401: 		 cat


torch.return_types.topk(
values=tensor([103.3880,  73.5215,  70.8671,  70.1262,  47.2687,  43.5433,  42.4167,
         31.5071,  28.5573,  27.2919,  27.0627,  26.1791,  25.1344,  24.9358,
         23.7866,  23.0588,  21.4485,  18.9834,  18.1445,  17.3769,  16.7656,
         16.4360,  15.9879,  15.3771,  14.4296,  14.2035,  14.1157,  14.1082,
         13.8757,  13.5338,  13.1257,  12.9537]),
indices=tensor([ 3705,  6772,  6631,  8837, 14838, 14795, 15361,  9768, 11725,  3940,
        12675,  7600,  1692,  3019,  7083, 13757,  8476, 12082,   743, 15509,
         9988,  4197,  3645,  4039,  9388, 11104,  3174,  2468, 15572, 14859,
         4310,  3404]))

In [240]:
# html = get_dashboard_html(sae_release=sae_release, sae_id=sae_id, feature_idx=3717)
# IFrame(html, width=1200, height=600)

In [304]:
ihavecat_act = ihavecat[4][0][4]

In [305]:
ihavecat_act

tensor([ 1.8583, -2.8761, -7.3865,  ..., -1.6729, -3.3603, -4.7288])

In [306]:
encode = sae.encode(ihavecat_act)
encode

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [326]:
topk = torch.topk(encode, k=16)
topk

torch.return_types.topk(
values=tensor([103.3880,  73.5215,  70.8671,  70.1262,  47.2687,  43.5433,  42.4167,
         31.5071,  28.5573,  27.2919,  27.0627,  26.1791,  25.1344,  24.9358,
         23.7866,  23.0588]),
indices=tensor([ 3705,  6772,  6631,  8837, 14838, 14795, 15361,  9768, 11725,  3940,
        12675,  7600,  1692,  3019,  7083, 13757]))

In [354]:
# Sum the un-normalized vectors of four features, using rounded activation values.
# Each get_steering_vector call goes through sae.decode, which adds b_dec once,
# so the four-term sum contains b_dec four times; subtracting 3 * b_dec leaves
# exactly one copy, as a single decode of all four features would.
steering = (
    get_steering_vector(sae, 3705, 103, False)
    + get_steering_vector(sae, 6772, 73, False)
    + get_steering_vector(sae, 6631, 70, False)
    + get_steering_vector(sae, 8837, 70, False)
    - 3 * sae.b_dec
)
steering

tensor([[ 5.8366,  0.2981,  0.9478,  ..., -3.1647,  1.7556,  1.5096]])

In [355]:
steering_tensor = get_steering_vector(sae, torch.tensor([3705, 6772, 6631, 8837]), torch.tensor([103., 73., 70., 70.]), False)
steering_tensor

tensor([[ 5.8366,  0.2981,  0.9478,  ..., -3.1647,  1.7556,  1.5096]])

In [356]:
steering_topk = get_steering_vector(sae, topk.indices, topk.values, False)
steering_topk

tensor([[ 4.0292, -0.6819, -1.5606,  ..., -0.9899, -1.2499, -3.0969]])

In [357]:
ihavecat_sae = ihavecat[5][0][4]
reconstruction = sae.decode(ihavecat_sae)
reconstruction


tensor([ 1.8106, -0.3835, -6.8812,  ..., -1.6010, -2.3316, -4.1605])

In [358]:
cosine_similarity = torch.nn.functional.cosine_similarity(reconstruction, ihavecat_act.unsqueeze(0))
print(f"Cosine similarity: {cosine_similarity.item():.4f}")

Cosine similarity: 0.9577


In [359]:
cosine_similarity = torch.nn.functional.cosine_similarity(steering_topk, ihavecat_act.unsqueeze(0))
print(f"Cosine similarity: {cosine_similarity.item():.4f}")

Cosine similarity: 0.9157


In [360]:
cosine_similarity = torch.nn.functional.cosine_similarity(steering_tensor, ihavecat_act.unsqueeze(0))
print(f"Cosine similarity: {cosine_similarity.item():.4f}")


Cosine similarity: 0.8293


In [361]:
cosine_similarity = torch.nn.functional.cosine_similarity(steering, ihavecat_act.unsqueeze(0))
print(f"Cosine similarity: {cosine_similarity.item():.4f}")


Cosine similarity: 0.8293


In [362]:
# Calculate cosine similarity between steering_tensor and steering (the batched and the summed construction of the same vector)
cosine_similarity = torch.nn.functional.cosine_similarity(steering_tensor, steering)
print(f"Cosine similarity: {cosine_similarity.item():.4f}")

Cosine similarity: 1.0000


In [197]:
steering

tensor([[ 0.0165, -0.0119,  0.0168,  ..., -0.0133,  0.0365, -0.0012]])

In [None]:
# steer("I have a ", steering, layer)