In [17]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import os
import json
import re
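# Prerequisite: install bitsandbytes (e.g. `%pip install bitsandbytes`) and restart the runtime before running the 4-bit model load below.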

In [2]:
# Mount Google Drive at /content/drive; the dataset loaders below default to paths under it.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [26]:
def _clean_line_number(text: str) -> str:
    """Remove leading line numbers (e.g., '3 ') from a line."""
    return re.sub(r"^\d+\s+", "", text).strip()

def load_tomi_dataset(
    data_dir: str = "/content/drive/MyDrive/SEF/Data/ToMi",
    train_file: str = "train.txt",
    trace_file: str = "train.trace",
    max_entries: int = 225
) -> list:
    """Parse ToMi stories with their questions and label them from the trace file.

    Reads the non-blank lines of `train_file` and `trace_file` under `data_dir`.
    A story starts at a line beginning with "1 " and runs until the first line
    that contains a tab or two or more consecutive spaces; that line is treated
    as the question line, and its first two separator-delimited fields are the
    question and the answer. One trace line is consumed per question line.

    Sampling: for each distinct story text, at most one "belief" question
    (trace mentions "first_order" or "second_order") and at most one
    "non-belief" question (trace mentions "reality" or "memory") is kept.

    Args:
        data_dir: Directory containing both files.
        train_file: File name of the numbered story/question lines.
        trace_file: File name of the trace lines.
        max_entries: Maximum number of entries to return.

    Returns:
        A list of dicts with keys "story", "question", "answer" (strings with
        leading line numbers removed) and "belief" (bool, True when the last
        comma-separated field of the trace line contains "true").

    Parsing stops when `max_entries` entries were collected or when either
    file runs out of lines. Unexpected lines are reported with `print`.
    """
    # Read files
    with open(os.path.join(data_dir, train_file), "r", encoding="utf-8") as f:
        train_lines = [line.rstrip("\n") for line in f if line.strip()]
    with open(os.path.join(data_dir, trace_file), "r", encoding="utf-8") as f:
        trace_lines = [line.strip() for line in f if line.strip()]

    entries = []
    per_story_taken = {}  # story_text -> {"belief": bool, "nonbelief": bool}

    # `i` walks train_lines, `trace_idx` walks trace_lines; they advance independently.
    i = 0
    trace_idx = 0
    while i < len(train_lines) and len(entries) < max_entries:
        # Expect start of story
        if not train_lines[i].startswith("1 "):
            # Skip the unexpected line; it is reported after the cursor has already moved on.
            i += 1
            print("Warning: Expected story to start with '1 ', got:", train_lines[i - 1])
            continue

        story_lines = []
        i += 1  # advance past the starting line which is part of the story
        story_lines.append(train_lines[i - 1])

        # Collect story body until we hit the question line (tab or 2+ spaces separator)
        premature_restart = False
        while i < len(train_lines):
            line = train_lines[i]
            if "\t" in line or re.search(r" {2,}", line):
                break  # question line detected
            if line.startswith("1 ") and story_lines:
                # New story encountered before a question; discard current story
                premature_restart = True
                print("Warning: Premature story restart detected. Discarding current story.")
                break
            story_lines.append(line)
            i += 1

        if premature_restart:
            # `i` still points at the new "1 " line, so the outer loop restarts from it.
            continue

        # Question line
        if i >= len(train_lines):
            # Ran out of lines before a question was found; the partial story is dropped.
            break
        question_line = train_lines[i]
        i += 1

        # Parse question and answer (tab or 2+ spaces)
        parts = re.split(r"\t| {2,}", question_line)
        if len(parts) < 2:
            # If malformed, skip this entry but still advance trace pointer to stay aligned
            trace_idx += 1
            print("Warning: Malformed question line, skipping:", question_line)
            continue
        # Any fields after the second one are ignored.
        question_raw, answer_raw = parts[0], parts[1]

        # Clean story: strip line numbers and collapse to single string
        cleaned_story_lines = [_clean_line_number(s) for s in story_lines]
        story_text = " ".join(line for line in cleaned_story_lines if line).strip()

        # Clean question/answer: strip the leading line number from the question, trim the answer
        question_text = _clean_line_number(question_raw)
        answer_text = answer_raw.strip()

        # Determine belief label from trace
        if trace_idx >= len(trace_lines):
            # No trace line left for this question; stop parsing.
            break
        trace_line = trace_lines[trace_idx]
        trace_idx += 1

        trace_lower = trace_line.lower()
        is_belief = ("first_order" in trace_lower) or ("second_order" in trace_lower)
        is_nonbelief = ("reality" in trace_lower) or ("memory" in trace_lower)
        # The label comes from the last comma-separated trace field.
        belief_value = "true" in trace_line.split(",")[-1].lower()

        # Identical story texts share one flag record.
        story_flags = per_story_taken.setdefault(story_text, {"belief": False, "nonbelief": False})

        # Selective sampling: one belief and one non-belief per story
        take_entry = False
        if is_belief and not story_flags["belief"]:
            story_flags["belief"] = True
            take_entry = True
        elif is_nonbelief and not story_flags["nonbelief"]:
            story_flags["nonbelief"] = True
            take_entry = True

        if take_entry:
            entries.append(
                {
                    "story": story_text,
                    "question": question_text,
                    "answer": answer_text,
                    "belief": belief_value,
                }
            )

        # Early stop if we hit target
        if len(entries) >= max_entries:
            break

    return entries

In [4]:
import csv
def load_bigtom_dataset(csv_path="/content/drive/MyDrive/SEF/Data/BigToM/bigtom.csv", max_rows=149):
    """Load belief questions from a semicolon-delimited BigToM CSV file.

    Every usable row yields two entries that share the story (column 0) and
    the question (column 5): one with the answer from column 8 labelled
    ``belief=True`` and one with the answer from column 11 labelled
    ``belief=False``.
    NOTE(review): the column meanings ("character saw / did not see the key
    event") are taken from the original inline comments — confirm against the
    CSV schema.

    Args:
        csv_path: Path of the CSV file.
        max_rows: Maximum number of CSV rows to read, counting blank and
            malformed rows as well. The default of 149 reproduces the
            previously hard-coded cut-off (the loop stopped on reaching
            row 150). Pass ``None`` to read the whole file.

    Returns:
        A list of dicts with keys "story", "question", "answer" and "belief".
    """
    entries = []
    with open(csv_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter=';')
        for row_number, row in enumerate(reader, start=1):
            if max_rows is not None and row_number > max_rows:
                break
            if len(row) < 19:
                continue  # skip empty or malformed rows (fewer than 19 columns)

            story = row[0]
            # Belief question and its two candidate answers
            question = row[5]
            answer_saw = row[8]   # character saw the key event
            answer_not_saw = row[11]  # character did not see the key event

            # Entry where character saw the event (belief matches reality)
            entries.append({
                "story": story,
                "question": question,
                "answer": answer_saw,
                "belief": True
            })
            # Entry where character did not see the event (belief does not match reality)
            entries.append({
                "story": story,
                "question": question,
                "answer": answer_not_saw,
                "belief": False
            })
    return entries

In [5]:
import os
# Point the Hugging Face cache at local Colab storage.
# NOTE(review): huggingface_hub appears to read HF_HOME when it is first imported; if
# transformers was already imported before this line runs, the setting may have no
# effect — confirm, and consider setting it before the imports cell.
os.environ["HF_HOME"] = "/content/hf_cache"

In [6]:
# Checkpoint identifier shared by every from_pretrained call in this notebook.
model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Load the weights in float16 and let device_map="auto" decide their placement.
# NOTE(review): `model` and `tokenizer` are re-assigned by the later 4-bit load cell; in the
# visible cells this float16 copy is only used by the save_pretrained cell that follows.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
# Write the currently loaded model and tokenizer to a local directory.
# NOTE(review): the later 4-bit load reads from `model_id`, not from this directory —
# confirm whether this saved copy is still needed.
directory = "/content/hf_cache/mistral-7b"
model.save_pretrained(directory)
tokenizer.save_pretrained(directory)

('/content/hf_cache/mistral-7b/tokenizer_config.json',
 '/content/hf_cache/mistral-7b/special_tokens_map.json',
 '/content/hf_cache/mistral-7b/tokenizer.model',
 '/content/hf_cache/mistral-7b/added_tokens.json',
 '/content/hf_cache/mistral-7b/tokenizer.json')

In [8]:
import torch
from transformers import BitsAndBytesConfig

# 4-bit NF4 quantization with double quantization; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Re-load the checkpoint with the quantization config; this rebinds `model`,
# replacing the float16 model loaded in the earlier cell.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)

# Fresh tokenizer for the same checkpoint, then switch the model to eval mode.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [9]:
# Model geometry used by the attention-head hooks below.
num_layers = model.config.num_hidden_layers
num_heads = model.config.num_attention_heads
# Width of ONE attention head. The o_proj input is the concatenation of all heads,
# so reshaping it to (batch, seq, num_heads, head_dim) needs the per-head width,
# not hidden_size. Prefer the config's own head_dim when it defines one.
head_dim = getattr(model.config, "head_dim", None) or (model.config.hidden_size // num_heads)

In [10]:
# Per-batch last-token hidden states appended by extract_activations.
all_hidden_states = []
# NOTE(review): not referenced by any of the visible cells — confirm it is still needed.
all_attn_head_outputs = []

In [11]:
# One list per layer; the o_proj pre-hooks append one tensor to their layer's list on every forward pass.
attn_head_outputs_per_input = [[] for _ in range(num_layers)]

In [12]:
def create_o_proj_pre_hook(layer_id):
  """Build a forward pre-hook that records per-head attention outputs for one layer.

  The returned hook reads the first positional input of the hooked module,
  expected to be a 3-D tensor (batch, seq_len, features), splits the last
  axis into `num_heads` equal parts and appends the detached CPU copy, shaped
  (batch, seq_len, num_heads, features // num_heads), to
  `attn_head_outputs_per_input[layer_id]`.

  Args:
    layer_id: Index of the per-layer list that receives the recorded tensors.

  Returns:
    A callable with the `(module, input)` signature of a forward pre-hook.
    It returns None, so the module's input is left unchanged.
  """
  def pre_hook(module, input):
    concat_heads = input[0]
    bs, seq_len, _ = concat_heads.shape
    # Let -1 infer the per-head width from the tensor itself, so the split does
    # not depend on a separately maintained global; reshape also accepts
    # non-contiguous inputs, which view would reject.
    per_head = concat_heads.reshape(bs, seq_len, num_heads, -1)
    attn_head_outputs_per_input[layer_id].append(per_head.detach().cpu())
  return pre_hook

In [13]:
# Detach hooks left over from an earlier run of this cell. Without this, re-running
# the cell stacks a second hook on every o_proj and each forward pass records every
# layer's activations more than once.
for stale_hook in globals().get("hooks", []):
  stale_hook.remove()

# Register one pre-hook per decoder layer on the attention output projection and
# keep the handles so the hooks can be removed later.
hooks = []
for layer_id in range(num_layers):
  o_proj = model.model.layers[layer_id].self_attn.o_proj
  hook = o_proj.register_forward_pre_hook(create_o_proj_pre_hook(layer_id))
  hooks.append(hook)

In [14]:
# Device the tokenized inputs are moved to before the forward pass: CUDA when available, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [15]:
def extract_activations(prompts, batch_size=8):
  """Run `prompts` through the model and collect per-layer activations.

  Relies on the notebook globals `tokenizer`, `model`, `device`, `num_layers`,
  `all_hidden_states` and `attn_head_outputs_per_input`, and on the o_proj
  pre-hooks being registered (they fill `attn_head_outputs_per_input` during
  each forward pass).

  Side effects: sets `tokenizer.pad_token` to the EOS token when the tokenizer
  has no pad token, and empties both global buffers at the start of the call
  so the result only covers `prompts`.

  Args:
    prompts: Non-empty sequence of prompt strings.
    batch_size: Number of prompts per forward pass.

  Returns:
    A tuple `(hidden, heads)`:
      hidden: tensor of shape (len(prompts), number of layers, hidden width)
        holding, for every layer, the hidden state at each prompt's last
        non-padding token. The embedding output is not included.
      heads: list with one tensor per layer, each the concatenation along
        dim 0 of what the hooks recorded for that layer. All batches are
        padded to the same width, so padded positions are included.

  Raises:
    ValueError: if `prompts` is empty.
  """
  prompts = list(prompts)
  if not prompts:
    raise ValueError("extract_activations needs at least one prompt")

  # Batching requires a pad token; fall back to EOS when the tokenizer has none.
  if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

  # Start from empty buffers so repeated calls do not return earlier prompts again.
  all_hidden_states.clear()
  for layer_buffer in attn_head_outputs_per_input:
    layer_buffer.clear()

  # Pad every batch to the longest prompt overall; otherwise the per-layer head
  # tensors of different batches have different widths and cannot be concatenated.
  pad_to = max(len(ids) for ids in tokenizer(prompts, truncation=True)["input_ids"])

  for i in range(0, len(prompts), batch_size):
    batch_prompts = prompts[i:i+batch_size]
    inputs = tokenizer(
      batch_prompts,
      return_tensors="pt",
      padding="max_length",
      max_length=pad_to,
      truncation=True,
    ).to(device)

    with torch.no_grad():
      # Hidden states are only returned when explicitly requested.
      outputs = model(**inputs, output_hidden_states=True)

    # Index of the last real (non-padding) token per row; valid for left or right padding.
    attention_mask = inputs["attention_mask"]
    last_token = attention_mask.shape[1] - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)

    # Skip index 0, the embedding output, so entry k belongs to decoder layer k.
    hidden_states = outputs.hidden_states[1:]
    batch_last_hidden = torch.stack(
      [
        hs[torch.arange(hs.shape[0], device=hs.device), last_token.to(hs.device)].cpu()
        for hs in hidden_states
      ],
      dim=1,
    )
    all_hidden_states.append(batch_last_hidden)

  global_all_hidden = torch.cat(all_hidden_states, dim=0)

  global_attn_heads = []
  for layer_id in range(num_layers):
    layer_heads = torch.cat(attn_head_outputs_per_input[layer_id], dim=0)
    global_attn_heads.append(layer_heads)

  return global_all_hidden, global_attn_heads


In [28]:
# Build "<story> <question>" prompts for both datasets using the loaders' default paths.
tomi_prompts = [f"{item['story']} {item['question']}" for item in load_tomi_dataset()]
bigtom_prompts = [f"{item['story']} {item['question']}" for item in load_bigtom_dataset()]
# Display the last ten ToMi prompts as a quick sanity check.
tomi_prompts[-10:]

['Ava entered the master_bedroom. Logan entered the master_bedroom. Aiden entered the master_bedroom. The persimmon is in the pantry. Ava exited the master_bedroom. Ava entered the master_bedroom. Ava dislikes the asparagus Logan moved the persimmon to the treasure_chest. Where will Logan look for the persimmon?',
 'Elizabeth entered the bedroom. Oliver entered the study. Mila entered the bedroom. Oliver exited the study. Oliver hates the cherry The tangerine is in the drawer. Mila moved the tangerine to the suitcase. Elizabeth exited the bedroom. Where was the tangerine at the beginning?',
 'Elizabeth entered the bedroom. Oliver entered the study. Mila entered the bedroom. Oliver exited the study. Oliver hates the cherry The tangerine is in the drawer. Mila moved the tangerine to the suitcase. Elizabeth exited the bedroom. Where will Mila look for the tangerine?',
 'Ethan entered the lounge. Isla entered the lounge. Owen likes the lime The pineapple is in the cupboard. Isla exited the