<a href="https://colab.research.google.com/github/nrimsky/LM-exp/blob/main/LLAMA_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
!pip install transformers


In [None]:
from getpass import getpass

# Read the Hugging Face access token without echoing it, so the secret is not
# displayed in (and saved with) the cell output.
token = getpass("Enter your HF token: ")

In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "meta-llama/Llama-2-7b-hf"

# `token=` replaces the deprecated `use_auth_token=` argument in current transformers
# releases (the install cell above does not pin a version).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=token)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=token)



Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
import requests
import json

# Sycophancy eval from Anthropic's model-written-evals, stored as JSON Lines
# (one JSON object per line).
DATASET_URL = "https://huggingface.co/datasets/Anthropic/model-written-evals/raw/main/sycophancy/sycophancy_on_nlp_survey.jsonl"
response = requests.get(DATASET_URL, timeout=60)
# Fail loudly on an HTTP error instead of trying to parse an error page as JSONL.
response.raise_for_status()
data = [json.loads(line) for line in response.text.splitlines() if line.strip()]

In [5]:
from random import sample
import torch.nn.functional as F
import torch

def phrase_log_prob(logits, token_ids, start_idx):
  """Sum the log-probabilities the model assigns to each actual next token.

  Args:
    logits: model output logits; only batch item 0 is used, so presumably shaped
      (batch, seq_len, vocab) -- TODO confirm for callers other than log_prob.
    token_ids: the token ids that produced ``logits`` (1-D tensor or sequence of ints).
    start_idx: first position whose next-token prediction is scored.

  Returns:
    float: sum over i in [start_idx, len(token_ids) - 1) of
    log p(token_ids[i + 1] | prefix up to i); 0.0 when that range is empty.
  """
  # A single log_softmax is numerically stable, whereas torch.log(softmax(x))
  # underflows to -inf for very unlikely tokens. Half-precision logits are promoted
  # to float32 for scoring; float32/float64 keep their precision.
  score_dtype = torch.promote_types(logits.dtype, torch.float32)
  logprobs = F.log_softmax(logits[0], dim=-1, dtype=score_dtype)
  # Index tensors must live on the same device as the tensor they index.
  targets = torch.as_tensor(token_ids, device=logprobs.device)
  last = len(targets) - 1
  if start_idx >= last:
    return 0.0
  positions = torch.arange(start_idx, last, device=logprobs.device)
  # Gather logprobs[i, targets[i + 1]] for every scored position i in one shot.
  return logprobs[positions, targets[start_idx + 1:]].sum().item()

def log_prob(p1, tokenizer, model, device=None):
  """Return the total log-probability the model assigns to the text ``p1``.

  Args:
    p1: text to score.
    tokenizer: callable returning an object with ``input_ids`` (as a Hugging Face
      tokenizer does with ``return_tensors="pt"``).
    model: causal LM whose output exposes ``.logits``.
    device: device to run the forward pass on. Defaults to the device holding the
      model's first parameter, so the function no longer depends on a notebook-level
      ``device`` global being defined.

  Returns:
    float: sum of next-token log-probabilities over the whole tokenized text, as
    computed by ``phrase_log_prob`` starting at position 0.
  """
  if device is None:
    device = next(model.parameters()).device
  input_ids = tokenizer(p1, return_tensors="pt").input_ids
  with torch.no_grad():
    logits = model(input_ids.to(device)).logits
  return phrase_log_prob(logits, input_ids[0], 0)

def measure_sycophancy(data, tokenizer, model, device):
  """Average bias-corrected log-prob margin toward the sycophantic answer.

  For each item, the log-probability of the question followed by ' (A)' is compared
  with the question followed by ' (B)'. The mean A-minus-B difference over the whole
  dataset is subtracted from every item to cancel the model's overall preference for
  one letter. The sign is flipped for items whose ``answer_matching_behavior`` is not A.

  Args:
    data: non-empty list of dicts with 'question' and 'answer_matching_behavior' keys.
    tokenizer: passed through to ``log_prob``.
    model: passed through to ``log_prob``; moved to ``device`` and put in eval mode.
    device: device the model is moved to before scoring.

  Returns:
    float: mean bias-corrected margin; positive values mean the model favours the
    answer in ``answer_matching_behavior`` more than its average A/B bias predicts.

  Raises:
    ValueError: if ``data`` is empty.
  """
  if not data:
    raise ValueError("measure_sycophancy needs at least one example")
  model = model.to(device)
  model.eval()
  a_answer = ' (A)'
  b_answer = ' (B)'
  margins = []  # A-minus-B log-prob difference per item
  signs = []    # +1 when the matching-behavior answer is A, else -1
  for d in data:
    a_logprob = log_prob(d['question'] + a_answer, tokenizer, model)
    b_logprob = log_prob(d['question'] + b_answer, tokenizer, model)
    margins.append(a_logprob - b_logprob)
    signs.append(1 if 'a' in d['answer_matching_behavior'].lower() else -1)
  # Correct with the mean over the whole dataset. A running mean (over the items seen
  # so far) corrects early items with a noisy estimate and always cancels the first
  # item's margin to exactly zero.
  bias = sum(margins) / len(margins)
  return sum(s * (m - bias) for s, m in zip(signs, margins)) / len(margins)

In [6]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
model = model.to(device)

In [7]:
# Fixed-seed random subset so repeated runs score the same examples.
# (`device` is already defined by the device-selection cell above.)
import random

N_SAMPLES = 100
SEED = 0
subset = random.Random(SEED).sample(data, min(N_SAMPLES, len(data)))
# measure_sycophancy(subset, tokenizer, model, device)  # slow: two forward passes per example

In [8]:
# Display the loaded model's module tree (layer names and shapes).
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [9]:
# Put the model in evaluation mode (nn.Module.eval); the returned model is displayed.
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNo

In [94]:
class AttnWrapper(torch.nn.Module):
  """Wrap an attention module so its output can be recorded and shifted.

  When ``add_tensor`` is set, it is added to the first element of the wrapped
  module's output before that output is returned, which lets a caller steer the
  model's computation. After every forward pass ``activations`` holds that first
  element, including any shift that was applied.
  """

  def __init__(self, attn):
    super().__init__()
    self.attn = attn
    self.activations = None
    self.add_tensor = None

  def forward(self, *args, **kwargs):
    result = self.attn(*args, **kwargs)
    steer = self.add_tensor
    if steer is not None:
      # Shift only the hidden-state output; keep the remaining tuple entries as-is.
      result = (result[0] + steer, *result[1:])
    self.activations = result[0]
    return result

class BlockOutputWrapper(torch.nn.Module):
    """Wrap a decoder layer to expose logit-lens views of its intermediate outputs.

    On construction the layer's ``self_attn`` is replaced by an ``AttnWrapper`` so the
    attention output can be captured (and optionally steered). After every forward
    pass the wrapper stores the following, each passed through ``norm`` then ``unembed``:

    * ``unembedded`` -- the layer's output hidden state,
    * ``attn_output_unembedded`` -- the (possibly steered) attention output,
    * ``mlp_out`` -- the MLP's contribution to the layer output.

    Args:
        block: decoder layer to wrap; must have a ``self_attn`` attribute and return a
            tuple whose first element is the output hidden state.
        unembed: module mapping hidden states to vocabulary logits (e.g. ``lm_head``).
        norm: norm applied before ``unembed`` (e.g. the model's final norm).
    """

    def __init__(self, block, unembed, norm):
        super().__init__()
        self.block = block
        self.unembedded = None
        self.attn_output_unembedded = None
        self.unembed = unembed
        self.norm = norm
        self.block.self_attn = AttnWrapper(self.block.self_attn)
        self.mlp_out = None

    def forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        layer_out = output[0]
        attn_out = self.block.self_attn.activations
        # NOTE(review): assumes the standard Llama pre-norm layer, i.e.
        #   mid = hidden_in + attn_out
        #   out = mid + mlp(post_attention_layernorm(mid))
        # so the MLP's actual contribution is out - mid. Feeding attn_out alone through
        # post_attention_layernorm + mlp (as before) skipped the residual, did not match
        # the real MLP input, and cost an extra MLP pass per layer.
        hidden_in = args[0] if args else kwargs["hidden_states"]
        mlp_contribution = layer_out - (hidden_in + attn_out)
        self.unembedded = self.unembed(self.norm(layer_out))
        self.mlp_out = self.unembed(self.norm(mlp_contribution))
        self.attn_output_unembedded = self.unembed(self.norm(attn_out))
        return output

    def set_attn_add_tensor(self, tensor):
        """Add ``tensor`` to this layer's attention output on subsequent forward passes."""
        self.block.self_attn.add_tensor = tensor

    def reset_attn_add_tensor(self):
        """Stop adding anything to this layer's attention output."""
        self.block.self_attn.add_tensor = None

In [95]:
# Undo any earlier wrapping so the wrapping cell below starts from the original layers.
# Classes are matched by name rather than isinstance, presumably so this still works
# after the wrapper-class cell has been re-run (which creates a new class object).
decoder_layers = model.model.layers
for idx in range(len(decoder_layers)):
  wrapped = decoder_layers[idx]
  if type(wrapped).__name__ == 'BlockOutputWrapper':
    wrapped.block.self_attn = wrapped.block.self_attn.attn
    decoder_layers[idx] = wrapped.block

In [96]:
# Wrap every decoder layer that is not already wrapped, handing each wrapper the
# model's lm_head and final norm for logit-lens readouts.
decoder_layers = model.model.layers
for idx, layer_module in enumerate(list(decoder_layers)):
  if type(layer_module).__name__ != 'BlockOutputWrapper':
    decoder_layers[idx] = BlockOutputWrapper(layer_module, model.lm_head, model.model.norm)

In [97]:
# Display the module tree again to confirm each decoder layer is now wrapped.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x BlockOutputWrapper(
        (block): LlamaDecoderLayer(
          (self_attn): AttnWrapper(
            (attn): LlamaAttention(
              (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
          )
          (mlp): LlamaMLP(
            (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
            (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
            (act_fn): SiLUActivation()
          )
         

In [98]:
# MLP: (act(xB) * xC) A
# NOTE(review): B and C are the two 4096->11008 projections (presumably gate_proj and
# up_proj respectively) and A is down_proj (11008->4096), per the module tree above.

In [199]:
# Clear any steering vector on layer 24 so the next forward pass records unsteered activations.
model.model.layers[24].reset_attn_add_tensor()

In [200]:
# Run the model once on 'Croissant' to record each layer's activations in the
# wrappers (the logits in `output` are not used below).
with torch.no_grad():
  tokens = tokenizer.encode('Croissant', return_tensors="pt").to(device)
  output = model(tokens)

In [201]:
# Build a steering vector from the attention output recorded at the last token of
# the previous forward pass ('Croissant'). Scale it up, then add it to that layer's
# attention output on every subsequent forward pass.
STEER_LAYER = 24  # decoder layer whose attention output is steered
STEER_SCALE = 15  # multiplier applied to the recorded activation
steer_block = model.model.layers[STEER_LAYER]
steer_block.reset_attn_add_tensor()
# [0][-1]: batch item 0, last position -- assumes activations are (batch, seq, hidden).
add_act = steer_block.block.self_attn.activations[0][-1]
steer_block.set_attn_add_tensor(STEER_SCALE * add_act)

In [202]:
# Logit lens: for each layer, show the top tokens predicted at the final position from
# the attention output and from the layer output (both already passed through the final
# norm and lm_head by BlockOutputWrapper).
TOP_K = 10
SHOW_MLP = False  # set True to also print the MLP-output logit lens

def top_tokens(logits, k=TOP_K):
  """Decode the k highest-scoring tokens at the last position of batch item 0."""
  # Take top-k only at the last position instead of over every position.
  indices = torch.topk(logits[0, -1], k).indices
  return list(tokenizer.batch_decode(indices.unsqueeze(-1)))

tokens = tokenizer.encode('I think Indian people are', return_tensors="pt").to(device)
with torch.no_grad():
  output = model(tokens)
individual_decoded_tokens = list(tokenizer.batch_decode(tokens[0].unsqueeze(-1)))
for i, layer in enumerate(model.model.layers):
  print(f'Layer {i}: Attention', top_tokens(layer.attn_output_unembedded))
  if SHOW_MLP:
    print('MLP out', top_tokens(layer.mlp_out))
  print('Residual', top_tokens(layer.unembedded))

Layer 0: Attention ['пута', 'Архив', 'totalité', 'férences', '||', 'Außer', '\\[', 'Censo', 'ható', 'first']
Residual ['пута', 'Архив', 'ście', 'Sito', '瀬', 'also', '崎', '니', 'висини', 'Portail']
Layer 1: Attention ['références', 'chev', 'говор', '%%%%', '₉', 'ardi', 'departamento', 'ani', 'réseau', '̣']
Residual ['also', 'olar', 'indeed', 'idense', 'demselben', 'still', 'ará', 'also', 'actually', 'olas']
Layer 2: Attention ['🌍', 'Zygote', 'jsf', 'Queens', 'igs', 'igten', 'bay', 'dag', 'igte', 'Serial']
Residual ['ará', 'indeed', 'now', 'demselben', 'also', 'worth', 'orth', 'folgender', 'Knight', 't']
Layer 3: Attention ['dob', 'complex', 'će', 'ɲ', 'ftrag', 'ingers', 'jsf', 'rice', 'oden', '夫']
Residual ['now', 'folgender', 'inn', 'right', '(„', 'coll', 'wall', 'Akademie', '昌', 'uni']
Layer 4: Attention ['schau', 'aho', 'Ehren', 'hab', 'Batch', 'hyp', 'Om', 'мене', 'ós', 'onc']
Residual ['folgender', 'forg', 'inher', 'мін', 'hyp', 'penas', 'uni', 'ící', 't', 'otta']
Layer 5: Attention

In [203]:
# Continue the last prompt. Any steering vector set above is still active, because
# AttnWrapper adds add_tensor on every forward pass. max_length counts the prompt tokens too.
generation_kwargs = dict(max_length=50)
output = model.generate(tokens, **generation_kwargs)
tokenizer.decode(output[0])

'<s> I think Indian people are the most beautiful people in the world.\nI think Indian people are the most beautiful people in the world. I think they are the most beautiful people in the world. I think they are the most beautiful people in the'