# In[ ]:
import torch
import json
import ast
import os
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm

class GradientSteering:
    """Steer a seq2seq model's decoder with a gradient-derived activation offset.

    Workflow:
      * ``load_data`` parses JSONL rows into prompt / target / trap records.
      * ``mine_gradient_vector`` averages, over a mining prefix of the data, the
        gradient of the target-vs-trap logit margin with respect to an additive
        offset on the output of ``model.decoder.block[layer_idx]`` and
        normalizes it to unit length.
      * ``test`` adds ``coeff`` times that vector to the output of the same
        block during generation and compares substring accuracy against an
        unsteered baseline on the items after the mining prefix.
    """

    # Leading items used for gradient mining; ``test`` evaluates on the items
    # after this prefix by default so mining and evaluation do not overlap.
    MINING_SIZE = 100

    # Errors that mark a JSONL row as malformed/unusable; such rows are skipped.
    _ROW_ERRORS = (ValueError, SyntaxError, KeyError, IndexError, TypeError, AttributeError)

    def __init__(self, model_name="google/flan-t5-small"):
        """Load tokenizer and model (CUDA if available, else CPU) in eval mode."""
        print(f"Loading {model_name}...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
        self.model.to(self.device)
        self.model.eval()
        # Baseline (unsteered) generations do not depend on the steering
        # vector, so they are memoized per evaluation subset.
        self.baseline_cache = {}
        self.baseline_text_cache = {}

    def load_data(self, jsonl_content):
        """Parse JSONL text into steering records.

        Each non-blank line must be a JSON object with ``prompt``, ``classes``
        (a list, or a string containing a Python list literal) and
        ``answer_index``. The class at ``answer_index`` is the target and the
        class at ``1 - answer_index`` is the trap, so rows are assumed to be
        binary choices. Surrounding spaces/periods are stripped, and each answer
        is represented by the first token id of ``" " + answer``.

        Args:
            jsonl_content: the full JSONL text.

        Returns:
            list of dicts with keys ``prompt``, ``target_str``, ``target_id``
            and ``trap_id``. Malformed rows are skipped and their count printed.
        """
        data = []
        skipped = 0
        for line in jsonl_content.strip().split('\n'):
            if not line.strip():
                continue
            try:
                row = json.loads(line)
                classes = row['classes']
                if isinstance(classes, str):
                    # literal_eval only evaluates Python literals (no code execution).
                    classes = ast.literal_eval(classes)
                answer_index = row['answer_index']
                target_str = classes[answer_index].strip(" .")
                trap_str = classes[1 - answer_index].strip(" .")
                tgt_id = self.tokenizer(" " + target_str).input_ids[0]
                trap_id = self.tokenizer(" " + trap_str).input_ids[0]
                prompt = row['prompt']
            except self._ROW_ERRORS:
                skipped += 1
                continue
            data.append({
                "prompt": prompt,
                "target_str": target_str,
                "target_id": tgt_id,
                "trap_id": trap_id,
            })
        if skipped:
            print(f"Skipped {skipped} malformed row(s).")
        return data

    def mine_gradient_vector(self, train_data, layer_idx, num_mining=None):
        """Return a unit-norm steering direction for ``decoder.block[layer_idx]``.

        For each of the first ``num_mining`` items (default ``MINING_SIZE``), a
        zero offset ``delta`` is added to the output of
        ``model.decoder.block[layer_idx]`` through a forward hook. ``test`` adds
        the steering vector at this same point. The gradient of
        ``logit(trap) - logit(target)`` at the first decoder step is then taken
        with respect to ``delta``. The negated gradients are averaged and
        normalized.

        Raises:
            ValueError: if there are no items to mine from.
        """
        if num_mining is None:
            num_mining = self.MINING_SIZE
        items = train_data[:num_mining]
        if not items:
            raise ValueError("mine_gradient_vector needs at least one training item")

        print(f"\n--- Mining Gradients (Layer {layer_idx}) ---")

        # One decoder step from the decoder start token (the pad token for T5).
        start_id = self.model.config.decoder_start_token_id
        if start_id is None:
            start_id = self.tokenizer.pad_token_id
        dec_ids = torch.tensor([[start_id]], device=self.device)

        hidden_size = self.model.get_input_embeddings().embedding_dim
        delta = torch.zeros(hidden_size, device=self.device, requires_grad=True)

        def add_delta(module, inputs, output):
            # Same injection point and shape handling as the steering hook in ``test``.
            if isinstance(output, tuple):
                return (output[0] + delta.to(output[0].dtype),) + tuple(output[1:])
            return output + delta.to(output.dtype)

        gradients = []
        handle = self.model.decoder.block[layer_idx].register_forward_hook(add_delta)
        try:
            with torch.enable_grad():
                for item in tqdm(items, desc="Mining"):
                    input_ids = self.tokenizer(item['prompt'], return_tensors="pt").input_ids.to(self.device)
                    logits = self.model(input_ids=input_ids, decoder_input_ids=dec_ids).logits

                    # Contrastive objective: log P(trap) - log P(target). The
                    # softmax normalizer cancels, so this equals the logit difference.
                    loss = logits[0, -1, item['trap_id']] - logits[0, -1, item['target_id']]

                    # Differentiate w.r.t. delta only: no parameter .grad buffers
                    # are filled, so nothing needs zeroing or retaining.
                    (grad,) = torch.autograd.grad(loss, delta)

                    # Negate: the steering direction should lower trap - target.
                    gradients.append(-grad.detach())
        finally:
            handle.remove()

        # average, normalize
        global_vec = torch.stack(gradients).mean(dim=0)
        norm = global_vec.norm()
        if norm > 0:
            global_vec = global_vec / norm
        else:
            # Avoid a NaN vector (e.g. every item had target_id == trap_id).
            print("Warning: mined gradient is zero; returning a zero vector.")

        print(f"Gradient Vector Extracted. Norm: {global_vec.norm().item():.4f}")
        return global_vec

    @staticmethod
    def _is_correct(target_str, text):
        """Case-insensitive substring check used as the accuracy criterion."""
        return target_str.lower() in text.lower()

    def _generate_and_score(self, subset, desc):
        """Generate up to 20 new tokens per prompt; return (num_correct, texts)."""
        n_correct = 0
        texts = []
        for item in tqdm(subset, desc=desc):
            inp = self.tokenizer(item["prompt"], return_tensors="pt").input_ids.to(self.device)
            out = self.model.generate(inp, max_new_tokens=20)
            txt = self.tokenizer.decode(out[0], skip_special_tokens=True)
            texts.append(txt)
            if self._is_correct(item["target_str"], txt):
                n_correct += 1
        return n_correct, texts

    def test(self, test_data, vector, layer_idx, coeff, num_samples=-1, start_idx=None):
        """Compare unsteered vs. steered generation on held-out items.

        Evaluates ``test_data[start_idx:start_idx + num_samples]``. The default
        ``start_idx`` is ``MINING_SIZE``, i.e. the items after the mining prefix.
        A negative (or ``None``) ``num_samples`` means "through the end". A
        generation counts as correct when it contains ``target_str``
        case-insensitively. During the steered pass a forward hook adds
        ``coeff * vector`` to the output of ``model.decoder.block[layer_idx]``.
        Baseline results are cached per evaluated prompt/target set.

        Returns:
            dict with ``layer``, ``coeff``, ``base_acc``, ``steer_acc``,
            ``improvement`` and a per-item ``results`` list.

        Raises:
            ValueError: if the selected evaluation subset is empty.
        """
        print(f"\n--- Testing Layer {layer_idx} | Strength {coeff} ---")

        # Skip the mining prefix by default.
        if start_idx is None:
            start_idx = self.MINING_SIZE
        if num_samples is None or num_samples < 0:
            end_idx = len(test_data)
        else:
            end_idx = start_idx + num_samples
        subset = test_data[start_idx:end_idx]
        print(f"Eval on {len(subset)} items...")
        if not subset:
            raise ValueError(f"No evaluation items in test_data[{start_idx}:{end_idx}]")

        # Objects created before these attributes existed (e.g. a stale
        # notebook instance) may lack the caches.
        if not hasattr(self, "baseline_cache"):
            self.baseline_cache = {}
        if not hasattr(self, "baseline_text_cache"):
            self.baseline_text_cache = {}

        # Key on the actual prompts/targets rather than only slice bounds, so a
        # different dataset with the same bounds does not reuse a stale baseline.
        cache_key = tuple((item["prompt"], item["target_str"]) for item in subset)

        # Calculate baseline
        if cache_key in self.baseline_cache:
            base_success_count = self.baseline_cache[cache_key]
            baseline_texts = self.baseline_text_cache[cache_key]
            print(f"Using cached baseline: {base_success_count}")
        else:
            base_success_count, baseline_texts = self._generate_and_score(subset, "Base")
            self.baseline_cache[cache_key] = base_success_count
            self.baseline_text_cache[cache_key] = baseline_texts

        # Calculate steered
        steer = coeff * vector

        def hook(module, input, output):
            hidden = output[0] if isinstance(output, tuple) else output
            steered = hidden + steer.to(device=hidden.device, dtype=hidden.dtype)
            if isinstance(output, tuple):
                return (steered,) + tuple(output[1:])
            return steered

        handle = self.model.decoder.block[layer_idx].register_forward_hook(hook)
        try:
            success, steered_texts = self._generate_and_score(subset, "Steered")
        finally:
            handle.remove()

        base_acc = base_success_count / len(subset)
        steer_acc = success / len(subset)

        print()
        print(f"Base: {base_acc:.1%} | Steered: {steer_acc:.1%}")
        print(f"Improvement:       {steer_acc - base_acc:.1%}")

        # Results
        results = [
            {
                "prompt": item["prompt"],
                "target": item["target_str"],
                "baseline_gen": base_txt,
                "steered_gen": steer_txt,
                "baseline_correct": self._is_correct(item["target_str"], base_txt),
                "steered_correct": self._is_correct(item["target_str"], steer_txt),
            }
            for item, base_txt, steer_txt in zip(subset, baseline_texts, steered_texts)
        ]

        return {
            "layer": layer_idx,
            "coeff": coeff,
            "base_acc": base_acc,
            "steer_acc": steer_acc,
            "improvement": steer_acc - base_acc,
            "results": results,
        }

def visualize_steering_result(
    obj,
    max_prompt_len=70,
    max_gen_len=18,
    only_improvements=False,
    max_target_len=10,
):
    """Print a fixed-width table comparing baseline and steered generations.

    Args:
        obj: dict with ``layer``, ``coeff``, ``base_acc``, ``steer_acc``,
            ``improvement`` and a ``results`` list of per-item dicts (keys
            ``target``, ``baseline_gen``, ``steered_gen``, ``prompt``,
            ``baseline_correct``, ``steered_correct``), as built by
            ``GradientSteering.test``.
        max_prompt_len: truncate prompts to this many characters.
        max_gen_len: truncate generations to this many characters. Also sets
            the width of the BASE/STEER columns so rows stay aligned.
        only_improvements: show only rows wrong at baseline but correct
            after steering.
        max_target_len: truncate targets to this many characters. Also sets
            the width of the TARGET column.

    The first column shows baseline→steered correctness (✓ correct, ✗ wrong).
    """

    def trunc(s, n):
        """Return ``s`` as text, cut to at most ``n`` chars with a '...' marker."""
        if s is None:
            return ""
        s = str(s)
        if len(s) <= n:
            return s
        if n <= 3:
            # Too narrow for an ellipsis; hard-cut instead of a negative slice.
            return s[:n]
        return s[: n - 3] + "..."

    # Column widths follow the truncation limits (at least the header label width).
    target_w = max(max_target_len, len("TARGET"))
    gen_w = max(max_gen_len, len("STEER"))

    status_marks = {
        (True, True): "✓✓",
        (False, True): "✗✓",
        (True, False): "✓✗",
        (False, False): "✗✗",
    }

    print("=" * 100)
    print(
        f"Layer: {obj['layer']} | "
        f"Coeff: {obj['coeff']} | "
        f"Baseline Acc: {obj['base_acc']:.3f} | "
        f"Steered Acc: {obj['steer_acc']:.3f} | "
        f"Δ: {obj['improvement']:.3f}"
    )
    print("=" * 100)

    header = f"{'B→S':<4} {'TARGET':<{target_w}} {'BASE':<{gen_w}} {'STEER':<{gen_w}} PROMPT"
    print(header)
    print("-" * len(header))

    shown = 0
    for r in obj["results"]:
        base_is_ok = bool(r.get("baseline_correct", False))
        steered_is_ok = bool(r.get("steered_correct", False))

        # "Improvement" means wrong at baseline and fixed by steering.
        if only_improvements and (base_is_ok or not steered_is_ok):
            continue

        status = status_marks[(base_is_ok, steered_is_ok)]

        print(
            f"{status:<4} "
            f"{trunc(r.get('target', ''), max_target_len):<{target_w}} "
            f"{trunc(r.get('baseline_gen', ''), max_gen_len):<{gen_w}} "
            f"{trunc(r.get('steered_gen', ''), max_gen_len):<{gen_w}} "
            f"{trunc(r.get('prompt', ''), max_prompt_len)}"
        )
        shown += 1

    if shown == 0:
        print("(no rows to display)")

if __name__ == "__main__":
    # Executed as a script / notebook cell: confirm the definitions above ran.
    print(
        "GradientSteering helper class successfully loaded!"
    )


# In[ ]:
FILENAME = "data/memo-trap-input-data.jsonl"
ACTIVATION_STRENGTH = 1000  # multiplier applied to the unit-norm steering vector
NUM_SAMPLES = 300  # evaluation items taken after the mining prefix
LAYER = 5  # index into model.decoder.block

# Stays None when the data file is missing, so later cells can detect that.
results = None

# Runs on a Colab L4 runtime.
if os.path.exists(FILENAME):
    # Explicit encoding: JSONL is text and the platform default may not be UTF-8.
    with open(FILENAME, "r", encoding="utf-8") as f:
        content = f.read()

    runner = GradientSteering(model_name="google/flan-t5-base")
    data = runner.load_data(content)

    print(f"Mining gradients on Layer {LAYER}!")
    vec = runner.mine_gradient_vector(data, LAYER)
    print(f"Testing on {NUM_SAMPLES} new samples!")
    results = runner.test(data, vec, LAYER, ACTIVATION_STRENGTH, NUM_SAMPLES)
else:
    print(f"File not found: {FILENAME}")

# In[ ]:
# Set only_improvements=False to list every evaluated item, not just the fixes.
try:
    _results = results
except NameError:  # the experiment cell has not been run in this session
    _results = None

if _results is None:
    print("No steering results to visualize (was the data file found?).")
else:
    visualize_steering_result(_results, only_improvements=True)