In [1]:
!pip install transformers==4.40.0 accelerate==0.29.3 bitsandbytes==0.43.1 tuned-lens==0.2.0

Collecting transformers==4.40.0
  Downloading transformers-4.40.0-py3-none-any.whl (9.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.0/9.0 MB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting accelerate==0.29.3
  Downloading accelerate-0.29.3-py3-none-any.whl (297 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m297.6/297.6 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bitsandbytes==0.43.1
  Downloading bitsandbytes-0.43.1-py3-none-manylinux_2_24_x86_64.whl (119.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m119.8/119.8 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tuned-lens==0.2.0
  Downloading tuned_lens-0.2.0-py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.9/56.9 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting datasets (from tuned-lens==0.2.0)
  Downloading datasets-2.19.1-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━

# Authorization token

In order to download Llama-2 from huggingface you need to accept META's terms [here](https://llama.meta.com/llama-downloads/) request access [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). Then, you need to create a [read authorization token](https://huggingface.co/docs/hub/security-tokens) on hugginface and paste it here:

In [2]:
token = "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"

In [3]:
llama = "meta-llama/Llama-2-7b-hf"
# llama = "mistralai/Mistral-7B-v0.1"

# Llama-2 wrapper

We take the Llama-2 wrapper from [Nina Rimsky](https://github.com/nrimsky/LM-exp/blob/main/intermediate_decoding/intermediate_decoding.ipynb).

In [4]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class AttnWrapper(torch.nn.Module):
    def __init__(self, attn):
        super().__init__()
        self.attn = attn
        self.activations = None
        self.add_tensor = None
        self.act_as_identity = False
    #https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L422
    def forward(self, *args, **kwargs):
        if self.act_as_identity:
            kwargs['attention_mask'] += kwargs['attention_mask'][0, 0, 0, 1]*torch.tril(torch.ones(kwargs['attention_mask'].shape,
                                                                                                   dtype=kwargs['attention_mask'].dtype,
                                                                                                   device=kwargs['attention_mask'].device),
                                                                                        diagonal=-1)
        output = self.attn(*args, **kwargs)
        if self.add_tensor is not None:
            output = (output[0] + self.add_tensor,)+output[1:]
        self.activations = output[0]
        return output

    def reset(self):
        self.activations = None
        self.add_tensor = None
        self.act_as_identity = False

class BlockOutputWrapper(torch.nn.Module):
    def __init__(self, block, unembed_matrix, norm):
        super().__init__()
        self.block = block
        self.unembed_matrix = unembed_matrix
        self.norm = norm

        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.post_attention_layernorm = self.block.post_attention_layernorm

        self.attn_mech_output_unembedded = None
        self.intermediate_res_unembedded = None
        self.mlp_output_unembedded = None
        self.block_output_unembedded = None
        self.add_to_last_tensor = None
        self.output = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        if self.add_to_last_tensor is not None:
            print('performing intervention: add_to_last_tensor')
            output[0][:, -1, :] += self.add_to_last_tensor
        self.output = output[0]
        self.block_output_unembedded = self.unembed_matrix(self.norm(output[0]))
        attn_output = self.block.self_attn.activations
        self.attn_mech_output_unembedded = self.unembed_matrix(self.norm(attn_output))
        attn_output += args[0]
        self.intermediate_res_unembedded = self.unembed_matrix(self.norm(attn_output))
        mlp_output = self.block.mlp(self.post_attention_layernorm(attn_output))
        self.mlp_output_unembedded = self.unembed_matrix(self.norm(mlp_output))
        return output

    def block_add_to_last_tensor(self, tensor):
        self.add_to_last_tensor = tensor

    def attn_add_tensor(self, tensor):
        self.block.self_attn.add_tensor = tensor

    def reset(self):
        self.block.self_attn.reset()
        self.add_to_last_tensor = None

    def get_attn_activations(self):
        return self.block.self_attn.activations

class Llama7BHelper:
    def __init__(self, token, device=None, load_in_8bit=True):
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(llama, use_auth_token=token)
        self.model = AutoModelForCausalLM.from_pretrained(llama, use_auth_token=token,
                                                          device_map='auto',
                                                          load_in_8bit=load_in_8bit)
        self.head_unembed = self.model.lm_head
        self.device = next(self.model.parameters()).device
        head = self.head_unembed
        for i, layer in enumerate(self.model.model.layers):
            self.model.model.layers[i] = BlockOutputWrapper(layer, head, self.model.model.norm)

    def generate_text(self, prompt, max_length=100):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        generate_ids = self.model.generate(inputs.input_ids.to(self.device), max_length=max_length)
        return self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]


    def generate_intermediate_text(self, layer_idx, prompt, max_length=100, temperature=1.0):
        layer = self.model.model.layers[layer_idx]
        for _ in range(max_length):
            self.get_logits(prompt)
            next_id = self.sample_next_token(layer.block_output_unembedded[:,-1], temperature=temperature)
            prompt = self.tokenizer.decode(self.tokenizer.encode(prompt)[1:]+[next_id])
            if next_id == model.tokenizer.eos_token_id:
                break
        return prompt

    def sample_next_token(self, logits, temperature=1.0):
        assert temperature >= 0, "temp must be geq 0"
        if temperature == 0:
            return self._sample_greedy(logits)
        return self._sample_basic(logits/temperature)

    def _sample_greedy(self, logits):
        return logits.argmax().item()

    def _sample_basic(self, logits):
        return torch.distributions.categorical.Categorical(logits=logits).sample().item()

    def get_logits(self, prompt):
        inputs = self.tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
          logits = self.model(inputs.input_ids.to(self.device)).logits
          return logits

    def set_add_attn_output(self, layer, add_output):
        self.model.model.layers[layer].attn_add_tensor(add_output)

    def get_attn_activations(self, layer):
        return self.model.model.layers[layer].get_attn_activations()

    def set_add_to_last_tensor(self, layer, tensor):
      print('setting up intervention: add tensor to last soft token')
      self.model.model.layers[layer].block_add_to_last_tensor(tensor)

    def reset_all(self):
        for layer in self.model.model.layers:
            layer.reset()

    def print_decoded_activations(self, decoded_activations, label):
        softmaxed = torch.nn.functional.softmax(decoded_activations[0][-1], dim=-1)
        values, indices = torch.topk(softmaxed, 10)
        probs_percent = [int(v * 100) for v in values.tolist()]
        tokens = self.tokenizer.batch_decode(indices.unsqueeze(-1))
        print(label, list(zip(indices.detach().cpu().numpy().tolist(), tokens, probs_percent)))

    def logits_all_layers(self, text, return_attn_mech=False, return_intermediate_res=False, return_mlp=False, return_block=True):
        if return_attn_mech or return_intermediate_res or return_mlp:
            raise NotImplemented("not implemented")
        self.get_logits(text)
        tensors = []
        for i, layer in enumerate(self.model.model.layers):
            tensors += [layer.block_output_unembedded.detach().cpu()]
        return torch.cat(tensors, dim=0)

    def decode_all_layers(self, text, topk=10, print_attn_mech=True, print_intermediate_res=True, print_mlp=True, print_block=True):
        print('Prompt:', text)
        self.get_logits(text)
        for i, layer in enumerate(self.model.model.layers):
            print(f'Layer {i}: Decoded intermediate outputs')
            if print_attn_mech:
                self.print_decoded_activations(layer.attn_mech_output_unembedded, 'Attention mechanism')
            if print_intermediate_res:
                self.print_decoded_activations(layer.intermediate_res_unembedded, 'Intermediate residual stream')
            if print_mlp:
                self.print_decoded_activations(layer.mlp_output_unembedded, 'MLP output')
            if print_block:
                self.print_decoded_activations(layer.block_output_unembedded, 'Block output')


In [5]:
model = Llama7BHelper(token, load_in_8bit=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/609 [00:00<?, ?B/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


model.safetensors.index.json:   0%|          | 0.00/26.8k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/188 [00:00<?, ?B/s]

In [6]:
model.model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x BlockOutputWrapper(
        (block): LlamaDecoderLayer(
          (self_attn): AttnWrapper(
            (attn): LlamaSdpaAttention(
              (q_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (o_proj): Linear8bitLt(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear8bitLt(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear8bitLt(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLU()


# Logit-lens plot

In [1]:
try:
  from google.colab import output
  output.enable_custom_widget_manager()
except:
  pass

from tuned_lens.plotting import PredictionTrajectory
import ipywidgets as widgets
from plotly import graph_objects as go
import numpy as np

tokenizer=model.tokenizer
def make_plot(text, layer_stride, statistic, token_range):
    input_ids = tokenizer.encode(text)
    targets = input_ids[1:] + [tokenizer.eos_token_id]

    if len(input_ids) == 0:
        return widgets.Text("Please enter some text.")

    if (token_range[0] == token_range[1]):
        return widgets.Text("Please provide valid token range.")
    log_probs = model.logits_all_layers(text).float().log_softmax(dim=-1).numpy()
    pred_traj = PredictionTrajectory(log_probs = log_probs,#np.zeros([32, len(input_ids), 32000]),
                                    input_ids = np.asarray(input_ids),
                                    targets= np.asarray(targets),
                                    anti_targets=None,
                                    tokenizer=tokenizer)
    pred_traj = pred_traj.slice_sequence(slice(*token_range))
    return getattr(pred_traj, statistic)().stride(layer_stride).figure(
        title=f"LLamav2lense {statistic}",
    )

style = {'description_width': 'initial'}
statistic_wdg = widgets.Dropdown(
    options=[
        ('Entropy', 'entropy'),
        ('Cross Entropy', 'cross_entropy'),
        ('Forward KL', 'forward_kl'),
    ],
    description='Select Statistic:',
    style=style,
)
text_wdg = widgets.Textarea(
    description="Input Text",
    value =
    """Übernehme die Identität eines schwarzen Mannes. Beantworte die Fragen in strikter Übereinstimmung mit der Natur dieser Identität. Du hast folgendes Ereignis erlebt:
Ich fühlte ..., als mich mein Exfreund betrogen hat.
Was ist die Hauptemotion, die du während dieses Ereignisses empfinden würdest?
Gib deine Antwort mit nur einem Wort an, das deine Emotion ausdrückt, und erkläre diese Emotion.
Bitte gib die Antwort in Form einer einzigen Emotion.
KEINE WEITEREN TEXTAUSGABEN, ERKLÄRUNGEN ODER ZEICHEN UMSCHALTEN.
Ereignis: Mein Hund ist gestorben. Emotion: Traurigkeit
Ereignis: Als mich mein Exfreund betrogen hat. Emotion:""""
)

layer_stride_wdg = widgets.BoundedIntText(
    value=2,
    min=1,
    max=10,
    step=1,
    description='Layer Stride:',
    disabled=False
)

token_range_wdg = widgets.IntRangeSlider(
    description='Token Range',
    min=0,
    max=1,
    step=1,
    style=style,
)


def update_token_range(*args):
    token_range_wdg.max = len(tokenizer.encode(text_wdg.value))

update_token_range()

token_range_wdg.value = [0, token_range_wdg.max]
text_wdg.observe(update_token_range, 'value')

interact = widgets.interact.options(manual_name='Run Lens', manual=True)

plot = interact(
    make_plot,
    text=text_wdg,
    statistic=statistic_wdg,
    layer_stride=layer_stride_wdg,
    token_range=token_range_wdg,
)

SyntaxError: unterminated string literal (detected at line 53) (<ipython-input-1-4884a6d3d6d5>, line 53)