In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os
import json
import re
import random
import csv
# install bitsandbytes and restart

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [22]:
def get_dataset(path):
  with open(path, 'r', encoding='utf-8') as f:
    data = json.load(f)
  filtered_data = [
      {'prompt': f'{entry['story']} {entry['question']}', 'belief': 1 if entry['belief'] else 0}
      for entry in data
  ]
  return filtered_data

In [24]:
tomi_train = get_dataset("/content/drive/MyDrive/SEF/Data/ToMi/tomi_train.json")
tomi_test = get_dataset("/content/drive/MyDrive/SEF/Data/ToMi/tomi_test.json")
bigtom_train = get_dataset("/content/drive/MyDrive/SEF/Data/BigToM/bigtom_train.json")
bigtom_test = get_dataset("/content/drive/MyDrive/SEF/Data/BigToM/bigtom_test.json")

In [5]:
import os
os.environ["HF_HOME"] = "/content/hf_cache"

In [34]:
model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding=True, truncation=True, model_max_length=512)
tokenizer.padding_side = "right"
tokenizer.truncation_side = "right"

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [35]:
directory = "/content/hf_cache/mistral-7b"
model.save_pretrained(directory)
tokenizer.save_pretrained(directory)

('/content/hf_cache/mistral-7b/tokenizer_config.json',
 '/content/hf_cache/mistral-7b/special_tokens_map.json',
 '/content/hf_cache/mistral-7b/tokenizer.model',
 '/content/hf_cache/mistral-7b/added_tokens.json',
 '/content/hf_cache/mistral-7b/tokenizer.json')

In [8]:
import torch
from transformers import BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [9]:
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
head_dim = model.config.hidden_size

In [43]:
attn_head_outputs_per_input = [[] for _ in range(num_layers)]

In [12]:
def create_o_proj_pre_hook(layer_id):
  def pre_hook(module, input):
    concat_heads = input[0]
    bs, seq_len, _ = concat_heads.shape
    per_head = concat_heads.view(bs, seq_len, num_heads, head_dim)
    attn_head_outputs_per_input[layer_id].append(per_head.detach().cpu())
  return pre_hook

In [13]:
hooks = []
for layer_id in range(num_layers):
  o_proj = model.model.layers[layer_id].self_attn.o_proj
  hook = o_proj.register_forward_pre_hook(create_o_proj_pre_hook(layer_id))
  hooks.append(hook)

In [14]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [44]:
def extract_activations(prompts, batch_size=8):
  all_hidden_states = []
  attn_head_outputs_per_input = [[] for _ in range(num_layers)]
  for i in range(0, len(prompts), batch_size):
    batch_prompts = prompts[i:i+batch_size]
    inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(device)

    with torch.no_grad():
      outputs = model(**inputs, output_hidden_states=True)

    hidden_states = outputs.hidden_states[1:]

    attention_mask = inputs["attention_mask"]  # (bs, seq_len)
    lengths = attention_mask.sum(dim=1) - 1     # index of last real token
    batch_last_hidden = []
    for layer_hs in hidden_states:
        # layer_hs: (bs, seq_len, hidden)
        last_tokens = layer_hs[torch.arange(layer_hs.size(0)), lengths]
        batch_last_hidden.append(last_tokens)

    batch_last_hidden = torch.stack(batch_last_hidden, dim=1)
    all_hidden_states.append(batch_last_hidden.cpu())

  global_all_hidden = torch.cat(all_hidden_states, dim=0)

  global_attn_heads = []
  for layer_id in range(num_layers):
    layer_heads = torch.cat(attn_head_outputs_per_input[layer_id], dim=0)
    global_attn_heads.append(layer_heads)

  return global_all_hidden, global_attn_heads


In [45]:
tomi_train_prompts = [entry['prompt'] for entry in tomi_train]
bigtom_train_prompts = [entry['prompt'] for entry in bigtom_train]

last_token_hidden_tomi, attn_head_outputs_tomi = extract_activations(tomi_train_prompts)
last_token_hidden_bigtom, attn_head_outputs_bigtom = extract_activations(bigtom_train_prompts)

ValueError: torch.cat(): expected a non-empty list of Tensors

In [None]:
for h in hooks:
  h.remove()