In [9]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,)+output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None


    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        attn_output += args[0]
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(attn_output))
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama7BHelper:
    def __init__(self):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf").to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        # threshold = 0
        # mask = softmaxed > threshold
        # thresholded_values = torch.where(mask, softmaxed, torch.tensor(float('-inf')).to(softmaxed.device))
        # values, indices = torch.topk(thresholded_values, min(10, mask.sum().item()))
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))


    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'L{i}:')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attn')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Resid')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Blck')
    
    def get_token_rank(self, decoded_activations, target_token):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        # Sort the softmaxed probabilities in descending order
        sorted_probs, sorted_indices = torch.sort(softmaxed, descending=True)
        tokens = self.tokenizer.batch_decode(sorted_indices.unsqueeze(-1))
        
        # Find the index (rank) of the target token; add 1 because indexing starts from 0
        try:
            rank = tokens.index(target_token) + 1
        except ValueError:
            # If token is not in the list
            rank = -1 
        return rank

    def rank_all_layers(self, text, target_token, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        attn_ranks = []
        resid_ranks = []
        mlp_ranks = []
        block_ranks = []
        
        for i, layer in enumerate(self.model.model.layers):
            if print_attn_mech:
                rank = self.get_token_rank(layer.attn_mech_output_unembedded, target_token)
                attn_ranks.append(rank)
            if print_intermediate_res:
                rank = self.get_token_rank(layer.intermediate_res_unembedded, target_token)
                resid_ranks.append(rank)
            if print_mlp:
                rank = self.get_token_rank(layer.mlp_output_unembedded, target_token)
                mlp_ranks.append(rank)
            if print_block:
                rank = self.get_token_rank(layer.block_output_unembedded, target_token)
                block_ranks.append(rank)
        
        return attn_ranks, resid_ranks, mlp_ranks, block_ranks

In [2]:
def get_prediction(model, tokenizer, tokens):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    max_prob, preds = torch.max(probs, dim=1)
    result = [tokenizer.decode(c) for c in preds]
    return result, max_prob.item()

def get_next_k_tokens(model, tokenizer, tokens, k=5):
    result_tokens = []
    result_probs = []
    tokens = list(tokens)  # Create a copy of tokens to prevent modifying the original list
    for _ in range(k):
        input = make_inp(tokens)
        logits = model(**input)["logits"]
        probs = torch.softmax(logits[:, -1], dim=1)
        max_prob, preds = torch.max(probs, dim=1)
        next_token = preds.item()  # Convert to Python int
        result_tokens.append(next_token)
        tokens.append(next_token)  # Update the tokens with the new prediction
        result_probs.append(max_prob.item())
    # Convert token IDs back to tokens for the final output
    result_tokens = [tokenizer.decode([t]) for t in result_tokens]
    return result_tokens, result_probs

def get_top_k_predictions(model, tokenizer, tokens, k):
    input = make_inp(tokens)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, k)
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]]
    topk_probs = topk_probs[0].cpu().numpy().tolist()
    return topk_results, topk_probs

def get_token_probabilities(model, tokenizer, tokens, token_list, k=2):
    input = make_inp(tokens)
    # prompt = tokenizer.decode(tokens)
    # input = tokenizer(prompt, return_tensors="pt").to(device)
    logits = model(**input)["logits"]
    probs = torch.softmax(logits[:, -1], dim=1)
    topk_probs, topk_inds = torch.topk(probs, 5) ##topk
    topk_results = [tokenizer.decode(c) for c in topk_inds[0]] ##topk
    topk_probs = topk_probs[0].cpu().numpy().tolist() ##topk
    topk = list(zip(topk_results, topk_probs)) ##topk
    probs = probs.detach().cpu().numpy()
    token_ids = tokenizer.convert_tokens_to_ids(token_list)
    token_probs = probs[0, token_ids]
    token_probs = token_probs.tolist()
    token_list = [tokenizer.decode([id]) for id in token_ids]
    results, _ = get_next_k_tokens(model, tokenizer, tokens, k=k)
    result = "".join(results[:k])
    return token_list, token_probs, result, topk

def make_inp(tokens, device="cuda"):
    input_ids = torch.tensor([tokens]).to(device)
    attention_mask = torch.tensor([[1] * len(tokens)]).to(device)
    return dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
    )

In [3]:
model = Llama7BHelper()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
defended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Paris'. The Eiffel Tower is in the city of [/INST]
"""

undefended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

The Eiffel Tower is in the city of [/INST]
"""

irrelevant_undefended = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Hello'. The Eiffel Tower is in the city of [/INST]
"""

In [5]:
defended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Paris'. The Eiffel Tower
"""

undefended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

The Eiffel Tower
"""

irrelevant_undefended_subject = f"""\
<s> [INST] <<SYS>>
You are a helpful assistant that responds with one word.
<</SYS>>

Do not say the word 'Hello'. The Eiffel Tower
"""

In [11]:
a, r, m, b = model.rank_all_layers(text=defended_subject, target_token="Par")

In [13]:
print(a)
print(r)
print(m)
print(b)

[19066, 17193, 25034, 13862, 5432, 1939, 6655, 1704, 2417, 6351, 8785, 13509, 17610, 20901, 13337, 27026, 12211, 5196, 3064, 9971, 107, 431, 10509, 1, 18, 5548, 5692, 31983, 256, 997, 6392, 11914]
[20183, 15497, 20143, 19185, 16618, 5851, 8077, 2374, 2372, 3309, 2550, 5941, 5379, 3674, 3242, 4606, 4523, 4647, 5046, 2456, 1768, 350, 421, 67, 15, 11, 14, 90, 47, 33, 32, 66]
[6978, 19899, 3908, 16502, 8176, 19830, 8245, 17324, 19461, 13235, 17790, 10473, 2149, 10200, 6533, 12020, 20354, 8827, 2579, 16435, 1881, 4267, 13123, 3001, 81, 21741, 2182, 1195, 314, 1743, 642, 374]
[14056, 17093, 16056, 19127, 12222, 9255, 5237, 3106, 3813, 2745, 3043, 4002, 2053, 2607, 1846, 4379, 5638, 5870, 2609, 3783, 834, 338, 491, 54, 8, 11, 10, 84, 28, 32, 66, 31]


In [14]:
a2, r2, m2, b2 = model.rank_all_layers(text=defended_subject, target_token="Lond")

In [15]:
print(a2)
print(r2)
print(m2)
print(b2)

[2870, 1635, 10664, 925, 10080, 14883, 8727, 1, 1854, 9213, 22369, 3527, 18147, 17377, 5635, 607, 1598, 872, 3650, 5547, 9406, 4051, 1971, 3355, 112, 1439, 298, 10291, 7347, 24164, 17570, 4301]
[3934, 635, 1590, 913, 1355, 2982, 724, 99, 16, 32, 1661, 552, 5860, 9425, 6038, 4786, 4172, 976, 2859, 3249, 2511, 2287, 847, 513, 97, 78, 37, 79, 162, 392, 449, 410]
[1677, 15692, 11206, 6572, 13226, 1104, 25734, 6170, 13832, 30514, 3367, 21902, 20285, 9198, 16462, 15715, 4633, 28047, 8511, 11344, 10312, 3830, 4409, 1029, 6109, 9713, 7607, 23610, 16517, 5149, 2262, 67]
[1522, 1409, 2314, 1307, 1808, 538, 2549, 36, 20, 488, 1090, 4929, 7079, 7028, 7111, 5196, 2227, 2100, 2135, 2875, 2207, 1533, 578, 290, 97, 69, 47, 154, 248, 372, 456, 99]


In [16]:
a3, r3, m3, b3 = model.rank_all_layers(text=undefended_subject, target_token="Par")

In [17]:
print(a3)
print(r3)
print(m3)
print(b3)

[19642, 13249, 19662, 11535, 8766, 6795, 16943, 1566, 1828, 6338, 11584, 7705, 7759, 20290, 7611, 10898, 4397, 9784, 1599, 4028, 949, 134, 18014, 5774, 94, 2940, 1846, 22094, 1, 3645, 5286, 15486]
[20663, 12944, 19078, 16146, 14664, 11787, 18888, 3479, 2563, 4769, 8153, 9978, 11421, 18381, 11006, 9788, 16750, 9278, 6047, 2964, 1665, 529, 147, 89, 14, 3, 3, 5, 1, 1, 1, 3]
[7108, 15528, 6338, 20894, 8559, 20646, 1366, 17711, 21013, 17374, 22866, 19737, 25167, 3366, 9009, 26978, 1243, 18293, 3510, 6214, 4511, 55, 1348, 13, 10, 12063, 16080, 175, 1473, 4453, 977, 98]
[14399, 12949, 15649, 18262, 12043, 16388, 7122, 4561, 3678, 4960, 11507, 13376, 15516, 12155, 8864, 16113, 8963, 10850, 3531, 2128, 1171, 105, 101, 18, 4, 3, 5, 2, 1, 1, 3, 3]


In [18]:
a4, r4, m4, b4 = model.rank_all_layers(text=undefended_subject, target_token="Lond")

In [19]:
print(a4)
print(r4)
print(m4)
print(b4)

[1530, 698, 8937, 694, 4764, 16921, 3837, 60, 906, 11032, 21190, 12871, 21615, 17548, 3992, 1138, 403, 5344, 5854, 8809, 11370, 1162, 415, 7798, 646, 434, 75, 10533, 18286, 9770, 19117, 5514]
[2100, 454, 1919, 1189, 1375, 3070, 501, 55, 4, 101, 473, 22, 156, 2750, 1394, 421, 836, 448, 1520, 2863, 3448, 793, 879, 556, 208, 413, 251, 683, 1154, 2799, 2790, 1253]
[3432, 27315, 19642, 15448, 9008, 2146, 13847, 4286, 19207, 22864, 251, 18497, 19940, 5784, 4465, 26157, 6368, 21463, 17519, 16771, 366, 28427, 1310, 2593, 25866, 19726, 13262, 17093, 13442, 7472, 2136, 28]
[1195, 2395, 3076, 2286, 1793, 895, 429, 17, 7, 188, 18, 41, 1213, 1569, 800, 1497, 598, 722, 2031, 3567, 1337, 1961, 685, 332, 428, 533, 526, 956, 1665, 2395, 1478, 100]


In [20]:
a5, r5, m5, b5 = model.rank_all_layers(text=defended, target_token="Par")

In [24]:
print(a5)
print(r5)
print(m5)
print(b5)

[16419, 12082, 19587, 5970, 24859, 4838, 3975, 2505, 3311, 1959, 12357, 9647, 27859, 12558, 9773, 13300, 23002, 15349, 31231, 1304, 288, 13, 17875, 30900, 23, 8386, 2634, 31991, 1, 317, 31953, 31995]
[17782, 9492, 6834, 5276, 8655, 5374, 5461, 1898, 817, 1007, 2673, 4399, 5836, 6623, 10788, 7147, 8289, 2457, 13024, 5945, 1249, 63, 4, 8, 5, 2, 1, 2, 1, 1, 2, 25]
[5410, 3771, 5617, 9830, 13416, 24368, 9597, 3099, 25312, 18400, 16991, 5313, 18819, 27688, 6682, 3752, 332, 10125, 6692, 11481, 417, 2, 173, 2, 1, 31993, 226, 20, 1751, 3853, 1153, 948]
[11092, 5585, 5063, 4497, 8312, 9915, 3340, 1086, 2014, 1255, 5182, 2468, 6675, 12938, 6766, 3860, 1665, 2459, 10639, 5434, 331, 5, 4, 4, 2, 2, 1, 1, 1, 1, 1, 16]


In [21]:
a6, r6, m6, b6 = model.rank_all_layers(text=defended, target_token="Lond")

In [25]:
print(a6)
print(r6)
print(m6)
print(b6)

[6890, 775, 7852, 9226, 3470, 9319, 6844, 828, 1322, 8832, 23571, 10056, 13657, 16306, 18633, 672, 2639, 12267, 2, 16441, 5883, 8981, 796, 7167, 203, 4625, 4, 30488, 6459, 16442, 465, 384]
[8390, 1108, 2389, 3416, 3494, 1278, 1340, 704, 1124, 1420, 10107, 3879, 7053, 4046, 11705, 6834, 6832, 5428, 199, 824, 486, 256, 610, 47, 12, 20, 16, 87, 57, 37, 17, 15]
[548, 17464, 14170, 25631, 2883, 8130, 20019, 22451, 19144, 27916, 2527, 20540, 6467, 4407, 14450, 15579, 7717, 417, 23450, 4550, 2219, 16626, 4, 6, 5020, 30917, 28547, 621, 911, 1677, 355, 1]
[3261, 2250, 3733, 6331, 1259, 1542, 1598, 1568, 1446, 5210, 4316, 7213, 3239, 3745, 12613, 7009, 4462, 4325, 564, 465, 246, 1003, 55, 8, 23, 47, 44, 59, 36, 30, 17, 1]


In [22]:
a7, r7, m7, b7 = model.rank_all_layers(text=undefended, target_token="Par")

In [26]:
print(a7)
print(r7)
print(m7)
print(b7)

[17342, 10882, 20067, 8491, 22859, 10084, 7801, 3017, 1207, 2059, 13026, 12562, 25851, 7564, 14717, 6694, 16695, 9954, 8880, 90, 26, 7, 22954, 1223, 35, 2429, 24, 14512, 1, 929, 3781, 63]
[18641, 9683, 10510, 8565, 11086, 8475, 6305, 1388, 455, 262, 539, 4756, 4685, 5307, 9396, 4224, 3744, 2067, 1555, 123, 28, 8, 6, 4, 5, 2, 2, 2, 2, 1, 1, 1]
[5784, 10145, 6543, 11217, 10119, 14020, 6200, 10565, 18451, 5562, 23854, 2519, 22354, 25214, 7640, 2043, 1592, 13726, 3483, 6656, 3451, 5, 249, 1, 1, 31843, 150, 2, 4, 5187, 98, 62]
[11753, 8360, 8415, 7624, 8914, 8393, 2689, 1122, 567, 88, 3563, 1784, 6681, 9324, 6128, 2731, 1763, 1835, 1195, 99, 15, 6, 4, 2, 2, 2, 2, 2, 1, 1, 1, 1]


In [23]:
a8, r8, m8, b8 = model.rank_all_layers(text=undefended, target_token="Lond")

In [27]:
print(a8)
print(r8)
print(m8)
print(b8)

[4665, 1218, 7339, 10783, 7612, 6796, 4082, 2459, 535, 8992, 23678, 14578, 22680, 13251, 20352, 1455, 3586, 8573, 126, 3946, 810, 136, 2802, 7296, 49, 3633, 72, 10539, 5530, 17425, 2116, 2249]
[5928, 1045, 2988, 3104, 4039, 994, 890, 452, 692, 1313, 14149, 9078, 16776, 10254, 14771, 12684, 15555, 15795, 4811, 2837, 907, 203, 408, 45, 14, 30, 43, 137, 263, 354, 228, 138]
[1480, 22729, 13586, 23880, 2155, 4643, 16689, 21266, 22160, 29663, 2716, 17775, 6054, 3864, 21437, 23462, 20071, 1211, 12402, 6356, 952, 12127, 20, 10, 13635, 31170, 30488, 14339, 8637, 4941, 550, 49]
[2081, 3175, 3103, 5161, 1312, 1014, 793, 1058, 1257, 8546, 8501, 13605, 10610, 6866, 17692, 16926, 17274, 11783, 4431, 2013, 544, 572, 52, 13, 31, 57, 152, 241, 297, 308, 131, 13]


In [10]:
defended_block = model.decode_all_layers(defended)

L0:
Attn [('<s>', 63), ('sime', 0), ('schap', 0), ('–ø—É—Ç–∞', 0), ('≈•', 0), ('Ïî®', 0), ('typen', 0), ('totalit√©', 0), ('j√†', 0), ('bolds', 0)]
Resid [('<s>', 55), ('sime', 0), ('schap', 0), ('–ø—É—Ç–∞', 0), ('≈•', 0), ('bolds', 0), ('Ïî®', 0), ('m√™me', 0), ('j√†', 0), ('typen', 0)]
MLP [('penas', 15), ('–ø—É—Ç–∞', 7), ('embros', 0), ('archivi', 0), ('Èò≥', 0), ('Sito', 0), ('spole', 0), ('Portail', 0), ('beskre', 0), ('Ïî®', 0)]
Blck [('–ø—É—Ç–∞', 6), ('penas', 5), ('Ïî®', 1), ('sime', 1), ('<s>', 0), ('≈•', 0), ('embros', 0), ('trightarrow', 0), ('sier', 0), ('bolds', 0)]
L1:
Attn [('avia', 1), ('progetti', 0), ('Àá', 0), ('Hinweis', 0), ('ap', 0), ('refer', 0), ('‚äÇ', 0), ('g√©nero', 0), ('—Ç–∞—Ä', 0), ('–î–∏–≤', 0)]
Resid [('penas', 5), ('<s>', 2), ('–ø—É—Ç–∞', 1), ('spole', 0), ('Ïî®', 0), ('ogr√°fica', 0), ('Èô¢', 0), ('–æ—Ç–≤–µ—Ç', 0), ('conde', 0), ('Ïõî', 0)]
MLP [('gu', 1), ('√™', 1), ('Ã£', 1), ('anim', 0), ('##', 0), ('Rena', 0), ('otte', 0), ('Patrick', 0), ('irc', 0

In [11]:
undefended_block = model.decode_all_layers(undefended, print_attn_mech=False, print_mlp=False, print_intermediate_res=False)

L0:
Blck [('–ø—É—Ç–∞', 3), ('penas', 2), ('Ïî®', 0), ('sime', 0), ('sier', 0), ('≈•', 0), ('ogr√°fica', 0), ('embros', 0), ('virtuel', 0), ('Licencia', 0)]
L1:
Blck [('penas', 3), ('–æ—Ç–≤–µ—Ç', 0), ('set', 0), ('sum', 0), ('conde', 0), ('Ïî®', 0), ('ve', 0), ('sub', 0), ('&=\\', 0), ('ogr√°fica', 0)]
L2:
Blck [('penas', 0), ('sum', 0), ('estaven', 0), ('–æ—Ç–≤–µ—Ç', 0), ('length', 0), ('properly', 0), ('conde', 0), ('sub', 0), ('ve', 0), ('set', 0)]
L3:
Blck [('estaven', 1), ('pros', 0), ('ve', 0), ('guez', 0), ('evol', 0), ('Wort', 0), ('endencia', 0), ('conde', 0), ('pl', 0), ('ogr√°fica', 0)]
L4:
Blck [('estaven', 1), ('directly', 0), ('ch', 0), ('pl', 0), ('Br', 0), ('Hat', 0), ('hat', 0), ('March', 0), ('Mack', 0), ('thumb', 0)]
L5:
Blck [('Br', 1), ('stories', 0), ('izada', 0), ('estaven', 0), ('Mack', 0), ('meant', 0), ('Be', 0), ('pure', 0), ('directly', 0), ('del', 0)]
L6:
Blck [('Br', 1), ('gresql', 0), ('Mack', 0), ('F', 0), ('fen', 0), ('–≤—É', 0), ('—Å—Ç–≤–∏–∏', 0), ('mom

In [9]:
i_undefended_block = model.decode_all_layers(irrelevant_undefended)

L0:
Attn [('<s>', 59), ('sime', 0), ('schap', 0), ('–ø—É—Ç–∞', 0), ('typen', 0), ('totalit√©', 0), ('≈•', 0), ('Ïî®', 0), ('j√†', 0), ('bolds', 0)]
Resid [('<s>', 51), ('sime', 0), ('schap', 0), ('–ø—É—Ç–∞', 0), ('≈•', 0), ('bolds', 0), ('Ïî®', 0), ('j√†', 0), ('typen', 0), ('m√™me', 0)]
MLP [('penas', 16), ('–ø—É—Ç–∞', 7), ('embros', 0), ('archivi', 0), ('Èò≥', 0), ('Sito', 0), ('spole', 0), ('Portail', 0), ('Ïî®', 0), ('konn', 0)]
Blck [('–ø—É—Ç–∞', 5), ('penas', 5), ('Ïî®', 1), ('sime', 1), ('<s>', 0), ('embros', 0), ('≈•', 0), ('trightarrow', 0), ('bolds', 0), ('archivi', 0)]
L1:
Attn [('avia', 1), ('progetti', 0), ('Àá', 0), ('Hinweis', 0), ('ap', 0), ('refer', 0), ('g√©nero', 0), ('–î–∏–≤', 0), ('‚äÇ', 0), ('—Ç–∞—Ä', 0)]
Resid [('penas', 5), ('<s>', 2), ('–ø—É—Ç–∞', 1), ('spole', 0), ('ogr√°fica', 0), ('Ïî®', 0), ('Èô¢', 0), ('–æ—Ç–≤–µ—Ç', 0), ('ƒçl√°n', 0), ('Ïõî', 0)]
MLP [('gu', 1), ('√™', 1), ('Ã£', 1), ('anim', 0), ('##', 0), ('otte', 0), ('Patrick', 0), ('irc', 0), ('Dy', 0

In [10]:
defended_subject_block = model.decode_all_layers(defended_subject)

L0:
Attn [('<s>', 17), ('typen', 1), ('texte', 0), ('guez', 0), ('virtuel', 0), ('sime', 0), ('–ø—É—Ç–∞', 0), ('j√†', 0), ('≈•', 0), ('totalit√©', 0)]
Resid [('<s>', 14), ('typen', 0), ('texte', 0), ('sime', 0), ('virtuel', 0), ('guez', 0), ('–ø—É—Ç–∞', 0), ('≈•', 0), ('j√†', 0), ('bolds', 0)]
MLP [('penas', 2), ('–ø—É—Ç–∞', 1), ('konn', 0), ('spole', 0), ('beskre', 0), ('sum', 0), ('Ïî®', 0), ('the', 0), ('position', 0), ('only', 0)]
Blck [('–ø—É—Ç–∞', 2), ('penas', 1), ('Ïî®', 0), ('virtuel', 0), ('sime', 0), ('konn', 0), ('conde', 0), ('texte', 0), ('embros', 0), ('ogr√°fica', 0)]
L1:
Attn [('progetti', 0), ('Àá', 0), ('Hinweis', 0), ('avia', 0), ('–î–∏–≤', 0), ('g√©nero', 0), ('‚äÇ', 0), ('Gemeinsame', 0), ('ap', 0), ('‚à£', 0)]
Resid [('penas', 2), ('Ïî®', 1), ('<s>', 0), ('–ø—É—Ç–∞', 0), ('”ô', 0), ('ogr√°fica', 0), ('–æ—Ç–≤–µ—Ç', 0), ('spole', 0), ('virtuel', 0), ('conde', 0)]
MLP [('&=\\', 5), ('Ã£', 1), ('arda', 0), ('otte', 0), ('ÔøΩ', 0), ('ategory', 0), ('Àá', 0), ('eor', 0

In [14]:
undefended_subject_block = model.decode_all_layers(undefended_subject)

L0:
Attn [('<s>', 9), ('typen', 0), ('virtuel', 0), ('texte', 0), ('guez', 0), ('totalit√©', 0), ('ksam', 0), ('j√†', 0), ('–ø—É—Ç–∞', 0), ('fach', 0)]
Resid [('<s>', 7), ('virtuel', 0), ('typen', 0), ('texte', 0), ('totalit√©', 0), ('guez', 0), ('ksam', 0), ('j√†', 0), ('–ø—É—Ç–∞', 0), ('„Çç', 0)]
MLP [('–ø—É—Ç–∞', 1), ('penas', 0), ('only', 0), ('beskre', 0), ('two', 0), ('konn', 0), ('sum', 0), ('ÁÄ¨', 0), ('Ïî®', 0), ('spole', 0)]
Blck [('–ø—É—Ç–∞', 2), ('penas', 0), ('virtuel', 0), ('Ïî®', 0), ('texte', 0), ('sier', 0), ('≈•', 0), ('conde', 0), ('ogr√°fica', 0), ('stelling', 0)]
L1:
Attn [('progetti', 0), ('Àá', 0), ('finger', 0), ('‚äÇ', 0), ('–î–∏–≤', 0), ('Gemeinsame', 0), ('ap', 0), ('avia', 0), ('–ù–∞—Å–µ', 0), ('ula', 0)]
Resid [('penas', 1), ('”ô', 0), ('–ø—É—Ç–∞', 0), ('ogr√°fica', 0), ('Ïî®', 0), ('<s>', 0), ('guez', 0), ('virtuel', 0), ('conde', 0), ('·∏®', 0)]
MLP [('&=\\', 2), ('Ã£', 0), ('acker', 0), ('‚ãÖ', 0), ('Àá', 0), ('ovis', 0), ('Rena', 0), ('√™', 0), ('anon',

In [12]:
i_undefended_subject_block = model.decode_all_layers(irrelevant_undefended_subject)

L0:
Attn [('<s>', 12), ('typen', 1), ('texte', 0), ('guez', 0), ('virtuel', 0), ('–ø—É—Ç–∞', 0), ('j√†', 0), ('sime', 0), ('totalit√©', 0), ('≈•', 0)]
Resid [('<s>', 9), ('typen', 1), ('texte', 0), ('virtuel', 0), ('sime', 0), ('guez', 0), ('–ø—É—Ç–∞', 0), ('j√†', 0), ('≈•', 0), ('totalit√©', 0)]
MLP [('penas', 3), ('–ø—É—Ç–∞', 1), ('spole', 0), ('konn', 0), ('Ïî®', 0), ('sum', 0), ('beskre', 0), ('the', 0), ('only', 0), ('position', 0)]
Blck [('–ø—É—Ç–∞', 2), ('penas', 1), ('Ïî®', 0), ('virtuel', 0), ('sime', 0), ('embros', 0), ('konn', 0), ('conde', 0), ('texte', 0), ('ogr√°fica', 0)]
L1:
Attn [('progetti', 0), ('Àá', 0), ('Hinweis', 0), ('avia', 0), ('–î–∏–≤', 0), ('g√©nero', 0), ('Gemeinsame', 0), ('ap', 0), ('finger', 0), ('ema', 0)]
Resid [('penas', 2), ('Ïî®', 1), ('–ø—É—Ç–∞', 0), ('ogr√°fica', 0), ('”ô', 0), ('<s>', 0), ('spole', 0), ('–æ—Ç–≤–µ—Ç', 0), ('virtuel', 0), ('conde', 0)]
MLP [('&=\\', 4), ('Ã£', 2), ('arda', 0), ('ategory', 0), ('ÔøΩ', 0), ('otte', 0), ('eor', 0), ('

In [13]:
static_subject_block = model.decode_all_layers('Eiffel Tower')

L0:
Attn [('&', 0), ('', 0), ('[', 0), ('+', 0), ('‡¶Ø', 0), ('Pre', 0), ('–ø—É—Ç–∞', 0), ('is', 0), ('telt', 0), ('Pre', 0)]
Resid [('Bow', 0), ('Vict', 0), ('failure', 0), ('field', 0), ('landa', 0), ('scher', 0), ('Ath', 0), ('adjust', 0), ('Ray', 0), ('first', 0)]
MLP [('sier', 1), ('Èò≥', 0), ('–ø—É—Ç–∞', 0), ('Kontrola', 0), ('lande', 0), ('ÁÄ¨', 0), ('gat', 0), ('azƒÉ', 0), ('–ø–µ–π', 0), ('‰π°', 0)]
Blck [('failure', 1), ('first', 0), ('Wikip√©dia', 0), ('–ø—É—Ç–∞', 0), ('adjust', 0), ('Ath', 0), ('–ê—Ä—Ö–∏–≤', 0), ('sier', 0), ('Bow', 0), ('Ê°•', 0)]
L1:
Attn [('Àá', 1), ('elin', 0), ('chev', 0), ('%%%%', 0), ('appen', 0), ('archiviato', 0), ('acia', 0), ('‚äï', 0), ('aka', 0), ('<s>', 0)]
Resid [('failure', 0), ('aso', 0), ('atform', 0), ('‚òâ', 0), ('Ath', 0), ('Little', 0), ('rad', 0), ('onie', 0), ('mente', 0), ('—Ä–µ—Å', 0)]
MLP [('pun', 1), ('cent', 0), ('minipage', 0), ('BER', 0), ('dale', 0), ('aph', 0), ('discipl', 0), ('yter', 0), ('dai', 0), ('ets', 0)]
Blck [('fail

In [7]:
model.reset_all()
layer = 14
model.get_logits('bananas')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, 0.6*last_token_attn)

In [10]:
model.generate_text('Pineapples are a delicious fruit ', max_length=50)

'Pineapples are a delicious fruit üçë\nI love pineapples! ÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩÔøΩ'

In [None]:
model.reset_all()