intermediate_decoding.py
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class AttnWrapper(torch.nn.Module):
    """Wraps an attention module to record its output and optionally add a steering tensor to it."""

    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            # Add the injected tensor to the attention output, keeping the rest of the output tuple intact
            output = (output[0] + self.add_tensor,) + output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None
class BlockOutputWrapper(torch.nn.Module):
    """Wraps a decoder block and unembeds its intermediate activations into vocabulary space."""

    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        # Full block output, unembedded
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
        # Attention mechanism output alone, unembedded
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        # Residual stream after adding the attention output back to the block input
        attn_output += args[0]
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(attn_output))
        # MLP output alone, unembedded
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()

    def get_attn_activations(self):
        return self.block.self_attn.activations
class Llama7BHelper:
    def __init__(self, token):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token)
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token).to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            logits = self.model(inputs.input_ids.to(self.device)).logits
            return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()
    def print_decoded_activations(self, decoded_activations, label, topk=10):
        # Softmax over the vocabulary at the final token position, then print the top-k tokens with their probabilities
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, topk)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(tokens, probs_percent)))

    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attention mechanism', topk=topk)
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Intermediate residual stream', topk=topk)
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP output', topk=topk)
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Block output', topk=topk)
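

if __name__ == "__main__":
    # Minimal usage sketch of the helper defined above: decode every layer's intermediate
    # outputs for a short prompt. "YOUR_HF_TOKEN" is a placeholder for a Hugging Face access
    # token with permission to download meta-llama/Llama-2-7b-hf; the prompt and topk value
    # are illustrative choices.
    helper = Llama7BHelper(token="YOUR_HF_TOKEN")
    helper.decode_all_layers("The capital of France is", topk=5)
    # Clear any stored activations and injected tensors before the next run
    helper.reset_all()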