In [1]:
%%capture
!pip install transformers


In [None]:
token = input("Enter your HF token: ")

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None

    def forward(self, *args, **kwargs):
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,)+output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None        
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None


    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        attn_output += args[0]
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(attn_output))
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama7BHelper:
    def __init__(self, token):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token)
        self.model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", use_auth_token=token).to(self.device)
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, self.model.lm_head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                softmaxed = torch.nn.functional.softmax(layer.attn_mech_output_unembedded[0], dim=-1)
                values, indices = torch.topk(softmaxed, topk)
                print(f'Attention mechanism', list(zip(self.tokenizer.batch_decode(indices[-1].unsqueeze(-1)), values.tolist())))
            if print_intermediate_res:
                softmaxed = torch.nn.functional.softmax(layer.intermediate_res_unembedded[0], dim=-1)
                values, indices = torch.topk(softmaxed, topk)
                print(f'Intermediate residual stream', list(zip(self.tokenizer.batch_decode(indices[-1].unsqueeze(-1)), values.tolist())))
            if print_mlp:
                softmaxed = torch.nn.functional.softmax(layer.mlp_output_unembedded[0], dim=-1)
                values, indices = torch.topk(softmaxed, topk)
                print(f'MLP output', list(zip(self.tokenizer.batch_decode(indices[-1].unsqueeze(-1)), values.tolist())))
            if print_block:
                softmaxed = torch.nn.functional.softmax(layer.block_output_unembedded[0], dim=-1)
                values, indices = torch.topk(softmaxed, topk)
                print(f'Block output', list(zip(self.tokenizer.batch_decode(indices[-1].unsqueeze(-1)), values.tolist())))



In [4]:
model = Llama7BHelper(token)



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
model.decode_all_layers('The capital of Germany is')


Layer 0
Attention output before residual connection 1 ['пута', 'Архив', '||', 'cí', 'totalité', 'Außer', '态', 'Под', '津', 'ksam']
Attention output ['nt', 'Архив', 'пута', 'ksam', 'Außer', 'Bedeut', 'Sito', 'totalité', 'folgender', '陽']
MLP output before residual connection 2 ['阳', 'Sito', 'archivi', 'ә', 'ұ', 'Portail', 'embros', 'пута', 'ѐ', '崎']
MLP output ['nt', 'пута', 'Архив', 'Sito', '阳', '陽', 'archivi', 'embros', 'ście', 'also']
Block output ['nt', 'пута', 'Архив', 'Sito', '阳', '陽', 'archivi', 'embros', 'ście', 'also']
Layer 1
Attention output before residual connection 1 ['références', '<s>', 'older', '☉', '�', '⁻', 'urus', '̣', 'ulas', 'chev']
Attention output ['nt', 'Архив', 'ә', 'Bedeut', 'archivi', 'Portail', 'ѐ', 'Außer', 'penas', 'Mor']
MLP output before residual connection 2 ['hing', 'wojew', 'Cell', 'cell', 'Ci', 'GC', 'Rico', 'Cell', 'prefix', 'hl']
MLP output ['nt', 'hing', 'also', 'penas', 'now', 'ѐ', 'ә', 'idense', 'ppi', 'Außer']
Block output ['nt', 'hing', 'also',

In [10]:
model.reset_all()
layer = 15
model.get_logits('Croissant, Cheese, Baguette')
attn = model.get_attn_activations(layer)
last_token_attn = attn[0][-1]
model.set_add_attn_output(layer, last_token_attn * 0.7)

In [11]:
model.generate_text('The capital of Germany is', max_length=21)

'The capital of Germany is the city of Paris, and the city of Paris is the capital of France'

In [8]:
model.reset_all()