In [82]:
import torch as t
import torch.nn as nn
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
from transformer_lens import HookedTransformer, ActivationCache, utils
from transformer_lens.hook_points import HookPoint
from torch import Tensor
from tqdm.auto import tqdm

In [8]:
t.set_grad_enabled(False)
model1 = HookedTransformer.from_pretrained("gpt2-small")
model1.eval()

Loaded pretrained model gpt2-small into HookedTransformer


HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [63]:
class ModelSteer:
    def __init__(self, model: HookedTransformer):
        self.coeff = None
        self.vec = None
        self.model = model
        
    def set_vec(self, vec: Float[Tensor, "n_tokens d_model"]):
        self.vec = vec
        
    def set_coeff(self, coeff):
        self.coeff = coeff
    
    def get_steer_vec(self, prompt_add: str, prompt_sub: str, layer: int) -> Float[Tensor, "n_tokens d_model"]:
        caches = [self.model.run_with_cache(seq)[1] for seq in self.pad_tokens(prompt_add, prompt_sub)]
        vecs = [cache["resid_pre", layer] for cache in caches]
        return vecs[0] - vecs[1]
        
 
    def pad_tokens(self, prompt_add: str, prompt_sub: str) -> tuple[str, str]:
        tokenlen = lambda prompt: self.model.to_tokens(prompt).shape[1]
        pad_right = lambda prompt, length: prompt + " " * (length - tokenlen(prompt))
        
        length = max(tokenlen(prompt_add), tokenlen(prompt_sub))
        return pad_right(prompt_add, length), pad_right(prompt_sub, length)
    
    def run_model_with_vec(self, prompt, layer: int):
        assert self.vec is not None and self.coeff is not None, "set_vec() and set_coeff() are required"
        self.model.reset_hooks()
        out = self.model.run_with_hooks(
            prompt,
            fwd_hooks=[(utils.get_act_name("resid_pre", layer), self.hook)]
        )
        self.model.reset_hooks()
        return out
        
    def hook(self, resid_pre: Float[Tensor, "batch seq d_model"], hook: HookPoint):
        expanded_vec = t.zeros_like(resid_pre)
        expanded_vec[:, :self.vec.shape[1], :] = self.vec
        return resid_pre + expanded_vec * self.coeff
        

In [88]:
def sample_with_temperature(logits: t.Tensor, temperature: float = 1.0) -> t.Tensor:
    if temperature == 0:
        return t.argmax(logits, dim=-1)
    probs = t.softmax(logits / temperature, dim=-1)
    return t.multinomial(probs, num_samples=1).squeeze(-1)

def sample_model(model_steer: ModelSteer, prompt: str, layer: int, n_samples: int = 5, max_tokens: int = 50, use_steering: bool = True, temperature: float = 0.7) -> List[str]:
    samples = []
    for _ in tqdm(range(n_samples), desc="Generating samples", leave=False):
        tokens = model_steer.model.to_tokens(prompt)
        for _ in tqdm(range(max_tokens), desc="Generating tokens", leave=False):
            if use_steering:
                output = model_steer.run_model_with_vec(tokens, layer)
            else:
                output = model_steer.model(tokens, return_type="logits")
            
            next_token = sample_with_temperature(output[0, -1, :], temperature)
            if next_token == model_steer.model.tokenizer.eos_token_id:
                break
            tokens = t.cat([tokens, next_token.unsqueeze(0).unsqueeze(0)], dim=1)
        
        generated_text = model_steer.model.tokenizer.decode(tokens[0])
        samples.append(generated_text)
    
    return samples

def compare_samples(model_steer: ModelSteer, prompt: str, layer: int, n_samples: int = 5, max_tokens: int = 50, temperature: float = 0.7) -> Tuple[List[str], List[str]]:
    print("Generating steered samples:")
    steered_samples = sample_model(model_steer, prompt, layer, n_samples, max_tokens, use_steering=True, temperature=temperature)
    print("\nGenerating non-steered samples:")
    non_steered_samples = sample_model(model_steer, prompt, layer, n_samples, max_tokens, use_steering=False, temperature=temperature)
    return steered_samples, non_steered_samples

def print_compared_samples(prompt: str, steered_samples: List[str], non_steered_samples: List[str]):
    print(f"\nPrompt: {prompt}\n")
    for i, (steered, non_steered) in enumerate(zip(steered_samples, non_steered_samples), 1):
        print(f"Sample {i}:")
        print(f"Steered: {steered}")
        print(f"Non-steered: {non_steered}")
        print(f"\n")


In [94]:
modelSteer = ModelSteer(model1)
diff_vec = modelSteer.get_steer_vec("I talk about weddings constantly", "I do not talk about weddings constantly", 5)
modelSteer.set_vec(diff_vec)
modelSteer.set_coeff(12)

In [95]:
prompt = "Anonymous 01/11/15 (Sun) 05:21:06 PM No. 81168 >>81169"
steered_samples, non_steered_samples = compare_samples(modelSteer, prompt, layer=6, n_samples=5, temperature=0.7, max_tokens=100)
print_compared_samples(prompt, steered_samples, non_steered_samples)

Generating steered samples:


Generating samples:   0%|          | 0/5 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]


Generating non-steered samples:


Generating samples:   0%|          | 0/5 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]

Generating tokens:   0%|          | 0/100 [00:00<?, ?it/s]


Prompt: Anonymous 01/11/15 (Sun) 05:21:06 PM No. 81168 >>81169

Sample 1:
Steered: <|endoftext|>Anonymous 01/11/15 (Sun) 05:21:06 PM No. 81168 >>81169

They're still not the talk that is going out on the blog, they're still not the biggest of the four. They're still not the most talk, but their talk about it. They're still, constantly discussing about it constantly... I'm very fond of the recent one in the point of fact. It's a beautiful idea. As one added one of the reasons for the family growing up, I'm thankful that my family is going to be married to people with information that is
Non-steered: <|endoftext|>Anonymous 01/11/15 (Sun) 05:21:06 PM No. 81168 >>81169

>>81162

And a lot of people are saying "Oh that's a fucking joke" then "What was that about?" Then they are all trying to make it seem like there is a valid reason for the fact that it's funny. And a lot of people are saying "Oh that's a fucking joke" then "What was that about?" Then they are all trying to make it seem li