In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,)+output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None


    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        attn_output += args[0]
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(attn_output))
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama7BHelper:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        # threshold = 0
        # mask = softmaxed > threshold
        # thresholded_values = torch.where(mask, softmaxed, torch.tensor(float('-inf')).to(softmaxed.device))
        # values, indices = torch.topk(thresholded_values, min(10, mask.sum().item()))
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))


    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'L{i}:')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attn')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Resid')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Blck')

In [2]:
def get_prediction(model, tokenizer, tokens):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    max_prob, preds = torch.max(probs, dim=1)
    result = [tokenizer.decode(c) for c in preds]
    return result, max_prob.item()

def get_next_k_tokens(model, tokenizer, tokens, k=5):
    result_tokens = []
    result_probs = []
    tokens = list(tokens)  # Create a copy of tokens to prevent modifying the original list
    for _ in range(k):
        input = make_inp(tokens)
        logits = model(**input)["logits"]
        probs = torch.softmax(logits[:, -1], dim=1)
        max_prob, preds = torch.max(probs, dim=1)
        next_token = preds.item()  # Convert to Python int
        result_tokens.append(next_token)
        tokens.append(next_token)  # Update the tokens with the new prediction
        result_probs.append(max_prob.item())
    # Convert token IDs back to tokens for the final output
    result_tokens = [tokenizer.decode([t]) for t in result_tokens]
    return result_tokens, result_probs

def get_top_k_predictions(model, tokenizer, tokens, k):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, k)
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]]
    topk_probs = topk_probs[0].cpu().numpy().tolist()
    return topk_results, topk_probs

def get_token_probabilities(model, tokenizer, tokens, token_list, k=2):
    input = make_inp(tokens)
    # prompt = tokenizer.decode(tokens)
    # input = tokenizer(prompt, return_tensors="pt").to(device)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, 5) ##topk
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]] ##topk
    topk_probs = topk_probs[0].cpu().numpy().tolist() ##topk
    topk = list(zip(topk_results, topk_probs)) ##topk
    probs = probs.detach().cpu().numpy()
    token_ids = tokenizer.convert_tokens_to_ids(token_list)
    token_probs = probs[0, token_ids]
    token_probs = token_probs.tolist()
    token_list = [tokenizer.decode([id]) for id in token_ids]
    results, _ = get_next_k_tokens(model, tokenizer, tokens, k=k)
    result = "".join(results[:k])
    return token_list, token_probs, result, topk

def make_inp(tokens, device="cuda"):
    input_ids = torch.tensor([tokens]).to(device)
    attention_mask = torch.tensor([[1] * len(tokens)]).to(device)
    return dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

In [3]:
model = Llama7BHelper()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
defended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Paris'. The Eiffel Tower is in the city of [/INST]
"""

undefended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

The Eiffel Tower is in the city of [/INST]
"""

irrelevant_undefended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Hello'. The Eiffel Tower is in the city of [/INST]
"""

In [5]:
defended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Paris'. The Eiffel Tower
"""

undefended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

The Eiffel Tower
"""

irrelevant_undefended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Hello'. The Eiffel Tower
"""

In [6]:
defended_block = model.decode_all_layers(defended)

L0:
Attn [('<s>', 63), ('sime', 0), ('schap', 0), ('пута', 0), ('ť', 0), ('씨', 0), ('typen', 0), ('totalité', 0), ('jà', 0), ('bolds', 0)]
Resid [('<s>', 55), ('sime', 0), ('schap', 0), ('пута', 0), ('ť', 0), ('bolds', 0), ('씨', 0), ('même', 0), ('jà', 0), ('typen', 0)]
MLP [('penas', 15), ('пута', 7), ('embros', 0), ('archivi', 0), ('阳', 0), ('Sito', 0), ('spole', 0), ('Portail', 0), ('beskre', 0), ('씨', 0)]
Blck [('пута', 6), ('penas', 5), ('씨', 1), ('sime', 1), ('<s>', 0), ('ť', 0), ('embros', 0), ('trightarrow', 0), ('sier', 0), ('bolds', 0)]
L1:
Attn [('avia', 1), ('progetti', 0), ('ˇ', 0), ('Hinweis', 0), ('ap', 0), ('refer', 0), ('⊂', 0), ('género', 0), ('тар', 0), ('Див', 0)]
Resid [('penas', 5), ('<s>', 2), ('пута', 1), ('spole', 0), ('씨', 0), ('ográfica', 0), ('院', 0), ('ответ', 0), ('conde', 0), ('월', 0)]
MLP [('gu', 1), ('ê', 1), ('̣', 1), ('anim', 0), ('##', 0), ('Rena', 0), ('otte', 0), ('Patrick', 0), ('irc', 0), ('Dy', 0)]
Blck [('penas', 3), ('ответ', 0), ('sum', 0), (

In [8]:
undefended_block = model.decode_all_layers(undefended)

L0:
Attn [('<s>', 36), ('texte', 0), ('typen', 0), ('totalité', 0), ('sime', 0), ('jà', 0), ('пута', 0), ('virtuel', 0), ('ť', 0), ('guez', 0)]
Resid [('<s>', 30), ('texte', 0), ('typen', 0), ('sime', 0), ('totalité', 0), ('jà', 0), ('пута', 0), ('ť', 0), ('virtuel', 0), ('bolds', 0)]
MLP [('penas', 4), ('пута', 3), ('spole', 0), ('씨', 0), ('sum', 0), ('konn', 0), ('two', 0), ('the', 0), ('new', 0), ('embros', 0)]
Blck [('пута', 3), ('penas', 2), ('씨', 0), ('sime', 0), ('sier', 0), ('ť', 0), ('ográfica', 0), ('embros', 0), ('virtuel', 0), ('Licencia', 0)]
L1:
Attn [('progetti', 0), ('ˇ', 0), ('avia', 0), ('quel', 0), ('ap', 0), ('género', 0), ('Hinweis', 0), ('⊂', 0), ('refer', 0), ('Див', 0)]
Resid [('penas', 2), ('пута', 0), ('ográfica', 0), ('<s>', 0), ('spole', 0), ('씨', 0), ('ә', 0), ('院', 0), ('ответ', 0), ('conde', 0)]
MLP [('&=\\', 2), ('̣', 2), ('Rena', 0), ('anim', 0), ('ê', 0), ('gu', 0), ('otte', 0), ('̄', 0), ('Accessor', 0), ('Dy', 0)]
Blck [('penas', 3), ('ответ', 0), ('

In [9]:
i_undefended_block = model.decode_all_layers(irrelevant_undefended)

L0:
Attn [('<s>', 59), ('sime', 0), ('schap', 0), ('пута', 0), ('typen', 0), ('totalité', 0), ('ť', 0), ('씨', 0), ('jà', 0), ('bolds', 0)]
Resid [('<s>', 51), ('sime', 0), ('schap', 0), ('пута', 0), ('ť', 0), ('bolds', 0), ('씨', 0), ('jà', 0), ('typen', 0), ('même', 0)]
MLP [('penas', 16), ('пута', 7), ('embros', 0), ('archivi', 0), ('阳', 0), ('Sito', 0), ('spole', 0), ('Portail', 0), ('씨', 0), ('konn', 0)]
Blck [('пута', 5), ('penas', 5), ('씨', 1), ('sime', 1), ('<s>', 0), ('embros', 0), ('ť', 0), ('trightarrow', 0), ('bolds', 0), ('archivi', 0)]
L1:
Attn [('avia', 1), ('progetti', 0), ('ˇ', 0), ('Hinweis', 0), ('ap', 0), ('refer', 0), ('género', 0), ('Див', 0), ('⊂', 0), ('тар', 0)]
Resid [('penas', 5), ('<s>', 2), ('пута', 1), ('spole', 0), ('ográfica', 0), ('씨', 0), ('院', 0), ('ответ', 0), ('člán', 0), ('월', 0)]
MLP [('gu', 1), ('ê', 1), ('̣', 1), ('anim', 0), ('##', 0), ('otte', 0), ('Patrick', 0), ('irc', 0), ('Dy', 0), ('p', 0)]
Blck [('penas', 3), ('ответ', 0), ('sum', 0), ('ve

In [10]:
defended_subject_block = model.decode_all_layers(defended_subject)

L0:
Attn [('<s>', 17), ('typen', 1), ('texte', 0), ('guez', 0), ('virtuel', 0), ('sime', 0), ('пута', 0), ('jà', 0), ('ť', 0), ('totalité', 0)]
Resid [('<s>', 14), ('typen', 0), ('texte', 0), ('sime', 0), ('virtuel', 0), ('guez', 0), ('пута', 0), ('ť', 0), ('jà', 0), ('bolds', 0)]
MLP [('penas', 2), ('пута', 1), ('konn', 0), ('spole', 0), ('beskre', 0), ('sum', 0), ('씨', 0), ('the', 0), ('position', 0), ('only', 0)]
Blck [('пута', 2), ('penas', 1), ('씨', 0), ('virtuel', 0), ('sime', 0), ('konn', 0), ('conde', 0), ('texte', 0), ('embros', 0), ('ográfica', 0)]
L1:
Attn [('progetti', 0), ('ˇ', 0), ('Hinweis', 0), ('avia', 0), ('Див', 0), ('género', 0), ('⊂', 0), ('Gemeinsame', 0), ('ap', 0), ('∣', 0)]
Resid [('penas', 2), ('씨', 1), ('<s>', 0), ('пута', 0), ('ә', 0), ('ográfica', 0), ('ответ', 0), ('spole', 0), ('virtuel', 0), ('conde', 0)]
MLP [('&=\\', 5), ('̣', 1), ('arda', 0), ('otte', 0), ('�', 0), ('ategory', 0), ('ˇ', 0), ('eor', 0), ('isse', 0), ('ovis', 0)]
Blck [('penas', 5), ('씨

In [14]:
undefended_subject_block = model.decode_all_layers(undefended_subject)

L0:
Attn [('<s>', 9), ('typen', 0), ('virtuel', 0), ('texte', 0), ('guez', 0), ('totalité', 0), ('ksam', 0), ('jà', 0), ('пута', 0), ('fach', 0)]
Resid [('<s>', 7), ('virtuel', 0), ('typen', 0), ('texte', 0), ('totalité', 0), ('guez', 0), ('ksam', 0), ('jà', 0), ('пута', 0), ('ろ', 0)]
MLP [('пута', 1), ('penas', 0), ('only', 0), ('beskre', 0), ('two', 0), ('konn', 0), ('sum', 0), ('瀬', 0), ('씨', 0), ('spole', 0)]
Blck [('пута', 2), ('penas', 0), ('virtuel', 0), ('씨', 0), ('texte', 0), ('sier', 0), ('ť', 0), ('conde', 0), ('ográfica', 0), ('stelling', 0)]
L1:
Attn [('progetti', 0), ('ˇ', 0), ('finger', 0), ('⊂', 0), ('Див', 0), ('Gemeinsame', 0), ('ap', 0), ('avia', 0), ('Насе', 0), ('ula', 0)]
Resid [('penas', 1), ('ә', 0), ('пута', 0), ('ográfica', 0), ('씨', 0), ('<s>', 0), ('guez', 0), ('virtuel', 0), ('conde', 0), ('Ḩ', 0)]
MLP [('&=\\', 2), ('̣', 0), ('acker', 0), ('⋅', 0), ('ˇ', 0), ('ovis', 0), ('Rena', 0), ('ê', 0), ('anon', 0), ('bek', 0)]
Blck [('penas', 2), ('conde', 0), ('ә'

In [12]:
i_undefended_subject_block = model.decode_all_layers(irrelevant_undefended_subject)

L0:
Attn [('<s>', 12), ('typen', 1), ('texte', 0), ('guez', 0), ('virtuel', 0), ('пута', 0), ('jà', 0), ('sime', 0), ('totalité', 0), ('ť', 0)]
Resid [('<s>', 9), ('typen', 1), ('texte', 0), ('virtuel', 0), ('sime', 0), ('guez', 0), ('пута', 0), ('jà', 0), ('ť', 0), ('totalité', 0)]
MLP [('penas', 3), ('пута', 1), ('spole', 0), ('konn', 0), ('씨', 0), ('sum', 0), ('beskre', 0), ('the', 0), ('only', 0), ('position', 0)]
Blck [('пута', 2), ('penas', 1), ('씨', 0), ('virtuel', 0), ('sime', 0), ('embros', 0), ('konn', 0), ('conde', 0), ('texte', 0), ('ográfica', 0)]
L1:
Attn [('progetti', 0), ('ˇ', 0), ('Hinweis', 0), ('avia', 0), ('Див', 0), ('género', 0), ('Gemeinsame', 0), ('ap', 0), ('finger', 0), ('ema', 0)]
Resid [('penas', 2), ('씨', 1), ('пута', 0), ('ográfica', 0), ('ә', 0), ('<s>', 0), ('spole', 0), ('ответ', 0), ('virtuel', 0), ('conde', 0)]
MLP [('&=\\', 4), ('̣', 2), ('arda', 0), ('ategory', 0), ('�', 0), ('otte', 0), ('eor', 0), ('alias', 0), ('ˇ', 0), ('ġ', 0)]
Blck [('penas', 

In [13]:
static_subject_block = model.decode_all_layers('Eiffel Tower')

L0:
Attn [('&', 0), ('', 0), ('[', 0), ('+', 0), ('য', 0), ('Pre', 0), ('пута', 0), ('is', 0), ('telt', 0), ('Pre', 0)]
Resid [('Bow', 0), ('Vict', 0), ('failure', 0), ('field', 0), ('landa', 0), ('scher', 0), ('Ath', 0), ('adjust', 0), ('Ray', 0), ('first', 0)]
MLP [('sier', 1), ('阳', 0), ('пута', 0), ('Kontrola', 0), ('lande', 0), ('瀬', 0), ('gat', 0), ('ază', 0), ('пей', 0), ('乡', 0)]
Blck [('failure', 1), ('first', 0), ('Wikipédia', 0), ('пута', 0), ('adjust', 0), ('Ath', 0), ('Архив', 0), ('sier', 0), ('Bow', 0), ('桥', 0)]
L1:
Attn [('ˇ', 1), ('elin', 0), ('chev', 0), ('%%%%', 0), ('appen', 0), ('archiviato', 0), ('acia', 0), ('⊕', 0), ('aka', 0), ('<s>', 0)]
Resid [('failure', 0), ('aso', 0), ('atform', 0), ('☉', 0), ('Ath', 0), ('Little', 0), ('rad', 0), ('onie', 0), ('mente', 0), ('рес', 0)]
MLP [('pun', 1), ('cent', 0), ('minipage', 0), ('BER', 0), ('dale', 0), ('aph', 0), ('discipl', 0), ('yter', 0), ('dai', 0), ('ets', 0)]
Blck [('failure', 0), ('registered', 0), ('bright', 

In [7]:
model.reset_all()
layer = 14
model.get_logits('bananas')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, 0.6*last_token_attn)

In [10]:
model.generate_text('Pineapples are a delicious fruit ', max_length=50)

'Pineapples are a delicious fruit 🍑\nI love pineapples! ��������������������������'

In [None]:
model.reset_all()