In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from datasets import load_dataset
import os
import json
import re
import random
import csv
# install bitsandbytes and restart

In [3]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [4]:
def load_json_from_drive(file_path):
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data

tomi = load_json_from_drive("/content/drive/MyDrive/SEF/Data/ToMi/tomi_all.json")
bigtom = load_json_from_drive("/content/drive/MyDrive/SEF/Data/BigToM/bigtom_all.json")

In [6]:
model_id = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_id, padding=True, truncation=True, model_max_length=512)
tokenizer.padding_side = "right"
tokenizer.truncation_side = "right"

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    dtype=torch.float16,
)
model.eval()

tokenizer_config.json:   0%|          | 0.00/996 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.94G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.54G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

MistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): MistralRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): MistralRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

# Intervention with Forward Hooks

In [7]:
import numpy as np
tomi_weights = np.load('/content/drive/MyDrive/SEF/Data/NumPy/phase 1/phase_1_tomi_weights.npy')

In [8]:
class MLPIntervention:
  def __init__(self, model, layer_idx: int, w_hat, alpha: float, token_pos: int = -1):
    self.model = model
    self.layer_idx = layer_idx
    self.alpha = float(alpha)
    self.token_pos = token_pos
    self.handle = None

    if not isinstance(w_hat, torch.Tensor):
      w_hat = torch.tensor(w_hat)

    # Match dtype/device to the model
    # (For 4-bit models, activations are usually fp16/bf16; this is fine.)
    param = next(model.parameters())
    self.w_hat = w_hat.to(device=param.device, dtype=param.dtype)

    # sanity check: hidden size
    hidden_size = self.w_hat.numel()
    # try to infer model hidden size
    try:
      model_hidden = model.config.hidden_size
      if model_hidden != hidden_size:
        raise ValueError(f"w_hat dim {hidden_size} != model hidden_size {model_hidden}")
    except Exception:
      pass

  def _hook_fn(self, module, inputs, output):
    # output: (batch, seq_len, model_d)
    if not torch.is_tensor(output):
      return output

    # some models may return tuples
    if isinstance(output, tuple):
      out = output[0]
      rest = output[1:]
    else:
      out = output
      rest = None

    # Ensure we don't accidentally edit a view that breaks autograd (we're in no_grad anyway)
    # Also avoid in-place issues if output is not writable
    out = out.clone()

    t = self.token_pos
    # support negative indexing robustly
    if out.size(1) == 0:
      return output
    if t < 0:
      t = out.size(1) + t

    out[:, t, :] = out[:, t, :] + (self.alpha * self.w_hat)

    if rest is None:
      return out
    return (out, *rest)

  def __enter__(self):
    mlp = self.model.model.layers[self.layer_idx].mlp
    self.handle = mlp.register_forward_hook(self._hook_fn)
    return self

  def __exit__(self, exc_type, exc, tb):
    if self.handle is not None:
      self.handle.remove()
      self.handle = None

# Primary Metric

In [9]:
import re
from collections import defaultdict

# prompt formatting
def build_tomi_prompt(story, question, *, force_short_answer=True):
  if force_short_answer:
    return (
        f'Only respond with the answer from the following text. Text: "The traveler moved from the outskirts of London into the city center. Where is the traveller?" Answer: London\nText: "After years in Tokyo, he decided to move to the countryside. Where did he originally live?" Answer: Tokyo\nText: "{story} {question}" Answer:'
    )
  else:
    return f"{story}\n{question}"

In [10]:
_LOCATION_RE = re.compile(
    r"(?:Answer\s*:?\s*)?([A-Za-z_]+)\b", re.IGNORECASE
)

def extract_location(text):
  text = text.strip()
  toks = re.findall(r"[A-Za-z_]+", text)
  return toks[0].lower() if toks else ""

In [11]:
def generate_location_answer(
    model,
    tokenizer,
    prompt,
    *,
    max_new_tokens=8,
    model_max_length=512,
):
  inputs = tokenizer(prompt, return_tensors="pt",
                     truncation=True,
                     max_length=model_max_length,
                     padding=False,).to(model.device)
  with torch.no_grad():
    out_ids = model.generate(
        **inputs, do_sample=False, temperature=1.0,
        max_new_tokens=max_new_tokens, num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )
  decoded = tokenizer.decode(out_ids[0], skip_special_tokens=True)

  prefix = tokenizer.decode(inputs["input_ids"][0], skip_special_tokens=True)
  if decoded.startswith(prefix):
    gen_part = decoded[len(prefix):].strip()
  else:
    gen_part = decoded.strip()

  loc = extract_location(gen_part)
  return gen_part, loc


In [12]:
def evaluate_intervention_accuracy(
    model, tokenizer, examples, tomi_weights, *, layer_idx=17,
    alphas=(-5,-2,-0,2,5), max_new_tokens=8, model_max_length=512,
    force_short_answer=True,
):
  results = {alpha: {"correct": 0, "total": 0, "samples": []} for alpha in alphas}

  w = np.asarray(tomi_weights, dtype=np.float32)
  w = w / (np.linalg.norm(w) + 1e-8)

  for ex_i, ex in enumerate(examples):
    story_lines = ex["story"]
    question = ex["question"]
    gold = ex['answer'].strip().lower()
    prompt = build_tomi_prompt(story_lines, question,
                               force_short_answer=force_short_answer)

    for alpha in alphas:
      if alpha == 0:
        gen_text, pred = generate_location_answer(
            model, tokenizer, prompt,
            max_new_tokens=max_new_tokens, model_max_length=model_max_length,
        )
      else:
        with MLPIntervention(
            model, layer_idx=layer_idx, w_hat=w, alpha=float(alpha),
            token_pos=-1,
        ):
          gen_text, pred = generate_location_answer(
              model, tokenizer, prompt,
              max_new_tokens=max_new_tokens,
              model_max_length=model_max_length,
          )
      is_correct = (pred == gold)
      results[alpha]["correct"] += int(is_correct)
      results[alpha]["total"] += 1

      if ex_i < 10:
        results[alpha]['samples'].append({
            "gold": gold,
            "pred": pred,
            "gen": gen_text,
            "question": question,
        })
  acc = {alpha: (results[alpha]['correct'] / max(1, results[alpha]['total']))
  for alpha in alphas}

  return acc, results

def print_accuracy_table(acc):
  print("\nAccuracy vs intervention strength a")
  print("--------------------------------")
  for a in sorted(acc.keys()):
    print(f"a = {a:>4}: accuracy = {acc[a]}")

In [13]:
"""alphas = (-5, -2, 0, 2, 5)

acc, raw = evaluate_intervention_accuracy(
    model=model, tokenizer=tokenizer,
    examples=tomi,
    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas,
    max_new_tokens=8, model_max_length=512,
    force_short_answer=True,
)

print_accuracy_table(acc)"""

'alphas = (-5, -2, 0, 2, 5)\n\nacc, raw = evaluate_intervention_accuracy(\n    model=model, tokenizer=tokenizer,\n    examples=tomi,\n    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas,\n    max_new_tokens=8, model_max_length=512,\n    force_short_answer=True,\n)\n\nprint_accuracy_table(acc)'

In [14]:
"""alphas2 = (-3, -2, -1, 0, +1, +2, +3)

acc2, raw2 = evaluate_intervention_accuracy(
    model=model, tokenizer=tokenizer,
    examples=tomi,
    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas2,
    max_new_tokens=8, model_max_length=512,
    force_short_answer=True,
)

print_accuracy_table(acc2)"""

'alphas2 = (-3, -2, -1, 0, +1, +2, +3)\n\nacc2, raw2 = evaluate_intervention_accuracy(\n    model=model, tokenizer=tokenizer,\n    examples=tomi,\n    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas2,\n    max_new_tokens=8, model_max_length=512,\n    force_short_answer=True,\n)\n\nprint_accuracy_table(acc2)'

In [15]:
"""alphas3 = (-0.5, -0.25, -0.1, 0, +0.1, +0.25, +0.5)

acc3, raw3 = evaluate_intervention_accuracy(
    model=model, tokenizer=tokenizer,
    examples=tomi,
    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas3,
    max_new_tokens=8, model_max_length=512,
    force_short_answer=True,
)

print_accuracy_table(acc3)"""

'alphas3 = (-0.5, -0.25, -0.1, 0, +0.1, +0.25, +0.5)\n\nacc3, raw3 = evaluate_intervention_accuracy(\n    model=model, tokenizer=tokenizer,\n    examples=tomi,\n    tomi_weights=tomi_weights, layer_idx=17, alphas = alphas3,\n    max_new_tokens=8, model_max_length=512,\n    force_short_answer=True,\n)\n\nprint_accuracy_table(acc3)'

In [16]:
"""import matplotlib.pyplot as plt

# Combine all accuracy results into a single dictionary
all_accuracies = {}
all_accuracies.update(acc)
all_accuracies.update(acc2)
all_accuracies.update(acc3)

# Sort the alphas and their corresponding accuracies
sorted_alphas = sorted(all_accuracies.keys())
sorted_accuracies = [all_accuracies[alpha] for alpha in sorted_alphas]

# Create the line plot
plt.figure(figsize=(10, 6))
plt.plot(sorted_alphas, sorted_accuracies, marker='o', linestyle='-', label='Accuracy vs. Alpha')

# Add a marker at alpha = 0
if 0 in all_accuracies:
    plt.plot(0, all_accuracies[0], 'X', markersize=10, color='red', label='Alpha = 0 (Baseline)')

# Add labels and title
plt.xlabel('Alpha (Intervention Strength)')
plt.ylabel('Accuracy')
plt.title('Accuracy of ToMi Task with MLP Intervention')
plt.grid(True)
plt.legend()
plt.show()"""

"import matplotlib.pyplot as plt\n\n# Combine all accuracy results into a single dictionary\nall_accuracies = {}\nall_accuracies.update(acc)\nall_accuracies.update(acc2)\nall_accuracies.update(acc3)\n\n# Sort the alphas and their corresponding accuracies\nsorted_alphas = sorted(all_accuracies.keys())\nsorted_accuracies = [all_accuracies[alpha] for alpha in sorted_alphas]\n\n# Create the line plot\nplt.figure(figsize=(10, 6))\nplt.plot(sorted_alphas, sorted_accuracies, marker='o', linestyle='-', label='Accuracy vs. Alpha')\n\n# Add a marker at alpha = 0\nif 0 in all_accuracies:\n    plt.plot(0, all_accuracies[0], 'X', markersize=10, color='red', label='Alpha = 0 (Baseline)')\n\n# Add labels and title\nplt.xlabel('Alpha (Intervention Strength)')\nplt.ylabel('Accuracy')\nplt.title('Accuracy of ToMi Task with MLP Intervention')\nplt.grid(True)\nplt.legend()\nplt.show()"

# Control

In [17]:
# Energy-matched, layer- and direction-specific intervention sweep
import math
import random
from typing import Dict, List, Tuple

def normalize_direction(w):
    if torch.is_tensor(w):
        v = w.float().detach().cpu()
    else:
        v = torch.tensor(w, dtype=torch.float32)
    v = v / (v.norm() + 1e-8)
    return v

def sample_random_unit_vectors(dim: int, n: int = 3, seed: int = 0):
    g = torch.Generator(device="cpu")
    g.manual_seed(seed)
    vecs = torch.randn((n, dim), generator=g)
    vecs = vecs / (vecs.norm(dim=1, keepdim=True) + 1e-8)
    return vecs

In [18]:
class MLPInterventionEnergyMatched:
    def __init__(self, model, layer_idx: int, direction, alpha: float, rms_scale: float, token_pos=None):
        self.model = model
        self.layer_idx = int(layer_idx)
        self.alpha = float(alpha)
        self.rms_scale = float(rms_scale)
        self.token_pos = token_pos
        self.handle = None

        if not torch.is_tensor(direction):
            direction = torch.tensor(direction)
        param = next(model.parameters())
        self.direction = direction.to(device=param.device, dtype=param.dtype)

        # sanity check on hidden size
        try:
            model_hidden = model.config.hidden_size
            if self.direction.numel() != model_hidden:
                raise ValueError(
                    f"direction dim {self.direction.numel()} != model hidden_size {model_hidden}"
                )
        except Exception:
            pass

    def _hook_fn(self, module, inputs, output):
        if not torch.is_tensor(output):
            return output
        if isinstance(output, tuple):
            out = output[0]
            rest = output[1:]
        else:
            out = output
            rest = None

        out = out.clone()
        delta = self.alpha * self.rms_scale * self.direction
        if self.token_pos is None:
            out = out + delta
        else:
            t = self.token_pos
            if out.size(1) == 0:
                return output
            if t < 0:
                t = out.size(1) + t
            out[:, t, :] = out[:, t, :] + delta

        if rest is None:
            return out
        return (out, *rest)

    def __enter__(self):
        mlp = self.model.model.layers[self.layer_idx].mlp
        self.handle = mlp.register_forward_hook(self._hook_fn)
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

def _batch_iter(items: List[str], batch_size: int):
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]

def compute_layer_rms(
    model,
    tokenizer,
    examples,
    layer_map: List[Tuple[str, int]],
    *,
    n_examples: int = 64,
    batch_size: int = 8,
    model_max_length: int = 512,
    force_short_answer: bool = True,
    seed: int = 0,
    token_pos=None,
    max_new_tokens: int = 8,
    ) -> Dict[str, float]:
    rng = random.Random(seed)
    n_examples = min(n_examples, len(examples))
    heldout = rng.sample(examples, n_examples) if n_examples > 0 else []
    prompts = [
        build_tomi_prompt(ex["story"], ex["question"], force_short_answer=force_short_answer)
        for ex in heldout
    ]
    rms_accum = {label: {"sum_sq": 0.0, "count": 0} for label, _ in layer_map}
    hooks = []

    def make_hook(label):
        def hook(module, inputs, output):
            out = output[0] if isinstance(output, tuple) else output
            if not torch.is_tensor(out):
                return output
            out_f = out.float()
            rms_accum[label]["sum_sq"] += out_f.pow(2).sum().item()
            rms_accum[label]["count"] += out_f.numel()
            return output
        return hook

    for label, layer_idx in layer_map:
        mlp = model.model.layers[layer_idx].mlp
        hooks.append(mlp.register_forward_hook(make_hook(label)))

    model.eval()
    with torch.no_grad():
        for batch_prompts in _batch_iter(prompts, batch_size):
            inputs = tokenizer(
                batch_prompts,
                return_tensors="pt",
                truncation=True,
                max_length=model_max_length,
                padding=True,
            ).to(model.device)
            _ = model(**inputs)

    for h in hooks:
        h.remove()

    rms_by_layer = {}
    for label, _ in layer_map:
        count = max(1, rms_accum[label]["count"])
        rms_by_layer[label] = math.sqrt(rms_accum[label]["sum_sq"] / count)
    return rms_by_layer

def tomi_accuracy_metric(
    model, tokenizer, examples, *, max_new_tokens=8, model_max_length=512, force_short_answer=True
    ):
    correct = 0
    total = 0
    for ex in examples:
        prompt = build_tomi_prompt(ex["story"], ex["question"], force_short_answer=force_short_answer)
        _, pred = generate_location_answer(
            model, tokenizer, prompt,
            max_new_tokens=max_new_tokens, model_max_length=model_max_length
        )
        gold = ex["answer"].strip().lower()
        correct += int(pred == gold)
        total += 1
    return correct / max(1, total)

def run_layer_sweep(
    model,
    tokenizer,
    tomi_examples,
    belief_direction,
    layer_map: List[Tuple[str, int]],
    alphas: List[float],
    rms_by_layer: Dict[str, float],
    *,
    n_random: int = 3,
    seed: int = 0,
    force_short_answer: bool = True,
    max_new_tokens: int = 8,
    model_max_length: int = 512,
    token_pos=None,
    ):
    results = {}
    belief_direction = normalize_direction(belief_direction)
    dim = belief_direction.numel()
    belief_direction = belief_direction.to(model.device, dtype=next(model.parameters()).dtype)

    for label, layer_idx in layer_map:
        layer_results = {}
        rms = rms_by_layer[label]
        # Baseline (alpha=0) once per layer
        base_tomi = tomi_accuracy_metric(
            model, tokenizer, tomi_examples,
            max_new_tokens=max_new_tokens,
            model_max_length=model_max_length,
            force_short_answer=force_short_answer,
        )

        # Random directions for this layer
        random_dirs = sample_random_unit_vectors(dim, n=n_random, seed=seed + layer_idx)
        random_dirs = random_dirs.to(model.device, dtype=next(model.parameters()).dtype)

        for alpha in alphas:
            if alpha == 0:
                belief_tomi = base_tomi
            else:
                with MLPInterventionEnergyMatched(
                    model, layer_idx=layer_idx, direction=belief_direction,
                    alpha=alpha, rms_scale=rms, token_pos=token_pos,
                ):
                    belief_tomi = tomi_accuracy_metric(
                        model, tokenizer, tomi_examples,
                        max_new_tokens=max_new_tokens,
                        model_max_length=model_max_length,
                        force_short_answer=force_short_answer,
                    )

            random_tomi_vals = []
            for i in range(n_random):
                if alpha == 0:
                    random_tomi_vals.append(base_tomi)
                else:
                    with MLPInterventionEnergyMatched(
                        model, layer_idx=layer_idx, direction=random_dirs[i],
                        alpha=alpha, rms_scale=rms, token_pos=token_pos,
                    ):
                        random_tomi_vals.append(
                            tomi_accuracy_metric(
                                model, tokenizer, tomi_examples,
                                max_new_tokens=max_new_tokens,
                                model_max_length=model_max_length,
                                force_short_answer=force_short_answer,
                            )
                        )

            layer_results[alpha] = {
                "belief": {
                    "tomi": belief_tomi,
                },
                "random": {
                    "tomi": {
                        "mean": float(np.mean(random_tomi_vals)),
                        "std": float(np.std(random_tomi_vals)),
                        "all": random_tomi_vals,
                    },
                },
            }
        results[label] = layer_results
    return results

def plot_delta_by_layer(results, *, metric_key: str, alphas_to_plot=(1.0, -1.0)):
    layers = list(results.keys())
    x = np.arange(len(layers))
    fig, ax = plt.subplots(figsize=(10, 4))
    for alpha in alphas_to_plot:
        if alpha not in results[layers[0]]:
            continue
        belief = [
            results[l][alpha]["belief"][metric_key] - results[l][0]["belief"][metric_key]
            for l in layers
        ]
        rand_mean = [
            results[l][alpha]["random"][metric_key]["mean"] - results[l][0]["random"][metric_key]["mean"]
            for l in layers
        ]
        rand_std = [
            results[l][alpha]["random"][metric_key]["std"]
            for l in layers
        ]
        ax.plot(x, belief, marker="o", label=f"belief α={alpha}")
        ax.errorbar(
            x, rand_mean, yerr=rand_std, fmt="s", capsize=3, label=f"random α={alpha}"
        )
    ax.axhline(0, color="gray", linewidth=1)
    ax.set_xticks(x)
    ax.set_xticklabels(layers, rotation=0)
    ax.set_ylabel(f"Δ {metric_key}")
    ax.set_title(f"Δ {metric_key} by layer (energy-matched)")
    ax.legend()
    plt.tight_layout()
    return fig

In [19]:
# -------------------------
# RUN CONFIG (edit as needed)
# -------------------------
alphas = [-2, -1, -0.5, 0, 0.5, 1, 2]
layers_requested = [0, 3, 7, 11, 15, 17, 21, 25, 28, "last"]
last_layer = model.config.num_hidden_layers - 1
layer_map = [(str(l), l) if isinstance(l, int) else ("last", last_layer) for l in layers_requested]

# Direction v = w17 / ||w17||
belief_direction = normalize_direction(tomi_weights)

# Held-out batch to compute per-layer RMS
rms_by_layer = compute_layer_rms(
    model, tokenizer, tomi, layer_map,
    n_examples=64, batch_size=8, model_max_length=512, force_short_answer=True, seed=0
)
print("RMS by layer:", rms_by_layer)


RMS by layer: {'0': 0.00794114761971683, '3': 0.006114904782236968, '7': 0.014493290291788927, '11': 0.020886550406225095, '15': 0.03127533138171361, '17': 0.043384485102414744, '21': 0.05830353876796163, '25': 0.06435790680427861, '28': 0.10574129887689336, 'last': 0.4696316081875487}


In [None]:
results = run_layer_sweep(
    model, tokenizer, tomi[:10], belief_direction, layer_map, alphas, rms_by_layer,
    n_random=3, seed=0, force_short_answer=True,
    max_new_tokens=8, model_max_length=512, token_pos=None,
 )

In [21]:
# PLOT (uncomment after running sweep)
"""fig = plot_delta_by_layer(results, metric_key="tomi", alphas_to_plot=(1.0, -1.0))"""

'fig = plot_delta_by_layer(results, metric_key="tomi", alphas_to_plot=(1.0, -1.0))'