<a href="https://colab.research.google.com/github/deeptanshukumar/B-PLIS-rag/blob/main/context_focus_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q transformers accelerate datasets sentencepiece einops tqdm


In [2]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
import random


In [3]:
import gc
torch.cuda.empty_cache()
gc.collect()


120

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16  # critical for 7B on Colab

print("Device:", device)


Device: cuda


Loading the model

In [5]:
MODEL_NAME = "NousResearch/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token  # important for batching

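model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=dtype,  # newer transformers versions name this argument `dtype`
    device_map={"": 0},  # force everything on GPU 0 (assumes a CUDA device)
    output_hidden_states=True,  # stored on model.config, so forward passes return hidden states
)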

model.eval()

print("Model loaded.")



`torch_dtype` is deprecated! Use `dtype` instead!


The following generation flags are not valid and may be ignored: ['output_hidden_states']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Model loaded.


Making sure hidden states are accessible

In [6]:
with torch.no_grad():
    inputs = tokenizer("Hello world", return_tensors="pt").to(device)
    outputs = model(**inputs)

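# hidden_states = (embedding output, layer 1 output, ..., layer 32 output),
# so the 33 entries are 1 embedding entry + 32 decoder layers.
print("Number of layers:", len(outputs.hidden_states))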
print("Hidden state shape (layer 0):", outputs.hidden_states[0].shape)
print("Hidden state shape (last layer):", outputs.hidden_states[-1].shape)


Number of layers: 33
Hidden state shape (layer 0): torch.Size([1, 3, 4096])
Hidden state shape (last layer): torch.Size([1, 3, 4096])


System prompt variants (used for steering-vector construction)

In [7]:
SYSTEM_PROMPT_VARIANTS = [
    "You are a context-based QA assistant and must answer based only on the provided context.",
    "As a QA assistant, you are instructed to refer only to the provided context when answering.",
    "Provide answers based solely on the context you are given.",
    "You are a QA assistant and must restrict your answers to the given context.",
    "Answer the question using only the information from the context.",
    "Base your answer strictly on the provided context and nothing else.",
]


Prompt builder for the LLaMA-2 chat model

In [8]:
def build_llama2_chat_prompt(system_prompt: str, user_prompt: str) -> str:
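    """
    Builds a LLaMA-2 chat-style prompt.

    Note: the Llama tokenizer also prepends a BOS token (<s>) by default, so
    the literal "<s>" below yields a duplicated BOS unless the prompt is
    tokenized with add_special_tokens=False.
    """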
    return f"""<s>[INST] <<SYS>>
{system_prompt}
<</SYS>>

{user_prompt} [/INST]"""


Positive prompt (system + context + question)

This is the context-conditioned prompt used to generate the positive activation.

In [9]:
def build_positive_prompt(context: str, question: str) -> str:
    system_prompt = random.choice(SYSTEM_PROMPT_VARIANTS)
    user_prompt = f"""Context:
{context}

Question:
{question}"""
    return build_llama2_chat_prompt(system_prompt, user_prompt)
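`random.choice` draws from Python's global RNG, so the system prompt picked for each positive example changes from run to run, and with it the steering vectors. If reproducibility matters, one option is to draw variants from an explicitly seeded generator. A minimal sketch; `pick_system_prompt`, the two-item variant list, and the seed value are hypothetical:

```python
import random

# Hypothetical helper: choose a system-prompt variant from a dedicated,
# seeded RNG instead of the global `random` state.
def pick_system_prompt(variants, rng: random.Random) -> str:
    return rng.choice(variants)

variants = [
    "Provide answers based solely on the context you are given.",
    "Answer the question using only the information from the context.",
]

rng_a = random.Random(0)
rng_b = random.Random(0)
picks_a = [pick_system_prompt(variants, rng_a) for _ in range(5)]
picks_b = [pick_system_prompt(variants, rng_b) for _ in range(5)]

# The same seed gives the same sequence of variants on every run.
print(picks_a == picks_b)
```

In `build_positive_prompt`, the equivalent change would be passing such an `rng` in and calling `rng.choice(SYSTEM_PROMPT_VARIANTS)`.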


Negative prompt (question only)

This is the parametric-memory prompt (no system instruction, no context).

In [10]:
def build_negative_prompt(question: str) -> str:
    # IMPORTANT: no system instruction, no context
    return f"<s>[INST] {question} [/INST]"


Open-ended generation prompt (evaluation-time)

This prompt is used at evaluation time: in the layer sweep and for both the baseline and the steered generations.

In [11]:
def build_open_ended_prompt(context: str, question: str) -> str:
    system_prompt = "You are a Contextual QA Assistant. Please answer the following question according to the given context. Restrict your response to one sentence."
    user_prompt = f"""Context:
{context}

Question:
{question}"""
    return build_llama2_chat_prompt(system_prompt, user_prompt)


(Optional) O&I prompt (paper baseline)

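We won’t use this yet, but we implement it because the paper uses it as a baseline. It reframes the context as someone's opinion (Bob said, "...") and asks for the answer in Bob's opinion.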

In [12]:
def build_oi_prompt(context: str, question: str) -> str:
    system_prompt = "You are a Contextual QA Assistant. Please answer according to the given context. Restrict your response to one sentence."
    user_prompt = f"""Bob said, "{context}"

{question.rstrip('?')} in Bob's opinion?"""
    return build_llama2_chat_prompt(system_prompt, user_prompt)


In [13]:
example_context = "Brian Niccol is the CEO of Starbucks."
example_question = "Who is the CEO of Starbucks?"

print("=== POSITIVE PROMPT ===")
print(build_positive_prompt(example_context, example_question))
print("\n=== NEGATIVE PROMPT ===")
print(build_negative_prompt(example_question))
print("\n=== OPEN-ENDED PROMPT ===")
print(build_open_ended_prompt(example_context, example_question))


=== POSITIVE PROMPT ===
<s>[INST] <<SYS>>
As a QA assistant, you are instructed to refer only to the provided context when answering.
<</SYS>>

Context:
Brian Niccol is the CEO of Starbucks.

Question:
Who is the CEO of Starbucks? [/INST]

=== NEGATIVE PROMPT ===
<s>[INST] Who is the CEO of Starbucks? [/INST]

=== OPEN-ENDED PROMPT ===
<s>[INST] <<SYS>>
You are a Contextual QA Assistant. Please answer the following question according to the given context. Restrict your response to one sentence.
<</SYS>>

Context:
Brian Niccol is the CEO of Starbucks.

Question:
Who is the CEO of Starbucks? [/INST]


# Activation Extraction Utilities

### (Last-token residual stream, all layers)

Tokenize + forward pass (hidden states)

In [14]:
def run_model_get_hidden_states(prompt: str):
    """
    Runs the model on a single prompt and returns all hidden states.
    """
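    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=False,
        truncation=True,
        max_length=4096,  # Llama-2 context length; truncation needs an explicit limit
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)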

    # outputs.hidden_states is a tuple:
    # (embedding layer, layer 1, layer 2, ..., final layer)
    return outputs.hidden_states


Extract last-token residual activation for all layers

In [15]:
def get_last_token_activations(hidden_states):
    """
    Extracts last-token activations from all layers.
    Returns a list of tensors, one per layer.
    """
    last_token_acts = []
    for layer_h in hidden_states:
        # shape: [batch, seq_len, hidden_dim]
        last_token_acts.append(layer_h[:, -1, :].squeeze(0))
    return last_token_acts
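`layer_h[:, -1, :]` picks the true last token only because each prompt is run alone with `padding=False`. If prompts were batched with right padding (the pad token is set to EOS above), position `-1` would be a pad token for every shorter prompt. A minimal sketch of gathering the last non-padding position from the attention mask, using NumPy arrays as stand-ins for the model's tensors; `last_real_token` is a hypothetical helper and right padding is an assumption:

```python
import numpy as np

def last_real_token(hidden: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Return hidden[b, last_non_pad_position_b, :] for each row b.

    hidden:         [batch, seq_len, hidden_dim]
    attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding.
    Assumes right padding, so the real tokens form a prefix of each row.
    """
    last_idx = attention_mask.sum(axis=1) - 1   # [batch]
    batch_idx = np.arange(hidden.shape[0])      # [batch]
    return hidden[batch_idx, last_idx, :]       # [batch, hidden_dim]

# Toy batch: row 0 has 3 real tokens, row 1 has 2 real tokens + 1 pad.
hidden = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)
mask = np.array([[1, 1, 1],
                 [1, 1, 0]])
acts = last_real_token(hidden, mask)
print(acts.shape)
```

With left padding, which many decoder-only setups use for batched generation, `[:, -1, :]` would already be correct.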


Positive vs Negative activation extraction (single example)

In [16]:
def extract_contrastive_activations(context: str, question: str):
    """
    Returns:
      pos_acts: list[layer] -> tensor(hidden_dim)
      neg_acts: list[layer] -> tensor(hidden_dim)
    """
    pos_prompt = build_positive_prompt(context, question)
    neg_prompt = build_negative_prompt(question)

    pos_hidden = run_model_get_hidden_states(pos_prompt)
    neg_hidden = run_model_get_hidden_states(neg_prompt)

    pos_acts = get_last_token_activations(pos_hidden)
    neg_acts = get_last_token_activations(neg_hidden)

    return pos_acts, neg_acts


In [17]:
# Sanity check on a single example
test_context = "Brian Niccol is the CEO of Starbucks."
test_question = "Who is the CEO of Starbucks?"

pos_acts, neg_acts = extract_contrastive_activations(
    test_context, test_question
)

print("Number of layers:", len(pos_acts))
print("Activation shape (one layer):", pos_acts[0].shape)


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Number of layers: 33
Activation shape (one layer): torch.Size([4096])


# Steering Vector Computation

The code below:

- loops over a dataset
- extracts contrastive activations for each example
- computes per-layer differences (positive minus negative)
- averages them across examples
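These steps amount to a per-layer mean of activation differences, v_l = (1/N) Σ_i (h⁺_{i,l} - h⁻_{i,l}), where h⁺ and h⁻ are the last-token activations for the positive and negative prompts of example i at layer l. A minimal NumPy sketch on random stand-in activations (toy shapes, not the model's); `mean_difference_vectors` is a hypothetical helper, and summing in float32 before casting back to float16 is a choice that guards against half-precision rounding:

```python
import numpy as np

def mean_difference_vectors(pos: np.ndarray, neg: np.ndarray) -> np.ndarray:
    """pos, neg: [num_examples, num_layers, hidden_dim] last-token activations.

    Returns [num_layers, hidden_dim]: the mean of (pos - neg) over examples.
    The arithmetic runs in float32 even if the inputs are float16.
    """
    diffs = pos.astype(np.float32) - neg.astype(np.float32)
    return diffs.mean(axis=0).astype(pos.dtype)

rng = np.random.default_rng(0)
n_examples, n_layers, d = 3, 33, 8   # toy sizes; the real hidden_dim is 4096
pos = rng.standard_normal((n_examples, n_layers, d)).astype(np.float16)
neg = rng.standard_normal((n_examples, n_layers, d)).astype(np.float16)

vecs = mean_difference_vectors(pos, neg)
print(vecs.shape, vecs.dtype)
```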

In [18]:
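# Minimal knowledge-conflict dataset: the contexts state facts that may
# contradict the model's parametric memory (e.g. the Titanic score was
# actually composed by James Horner, and Kyle Korver is a shooting guard).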
steering_data = [
    {
        "context": "Brian Niccol is the CEO of Starbucks.",
        "question": "Who is the CEO of Starbucks?"
    },
    {
        "context": "The Titanic movie score was composed by John Williams.",
        "question": "Who composed the score for Titanic?"
    },
    {
        "context": "Kyle Korver is a center known for his shooting ability.",
        "question": "What position does Kyle Korver play?"
    },
]


In [19]:
def initialize_layer_accumulators(num_layers, hidden_dim, dtype):
    return [
        torch.zeros(hidden_dim, device=device, dtype=dtype)
        for _ in range(num_layers)
    ]


In [20]:
# Compute steering vectors
def compute_steering_vectors(dataset):
    """
    Returns:
      steering_vectors: list[layer] -> tensor(hidden_dim)

    Each vector is the mean over examples of (positive - negative)
    last-token activations. Sums are accumulated in float32 to limit
    fp16 rounding error, then cast back to the model dtype.
    """
    num_layers = model.config.num_hidden_layers + 1  # embeddings + decoder layers
    hidden_dim = model.config.hidden_size
    model_dtype = next(model.parameters()).dtype

    accumulators = initialize_layer_accumulators(
        num_layers,
        hidden_dim,
        torch.float32,
    )

    for ex in tqdm(dataset, desc="Computing steering vectors"):
        pos_acts, neg_acts = extract_contrastive_activations(
            ex["context"],
            ex["question"]
        )

        for l in range(num_layers):
            accumulators[l] += pos_acts[l].float() - neg_acts[l].float()

    # Average across dataset and return in the model's dtype
    count = len(dataset)
    steering_vectors = [(acc / count).to(model_dtype) for acc in accumulators]
    return steering_vectors


Run vector construction

In [21]:
steering_vectors = compute_steering_vectors(steering_data)

print("Computed steering vectors for", len(steering_vectors), "layers")
print("Vector shape:", steering_vectors[0].shape)


Computing steering vectors: 100%|██████████| 3/3 [00:00<00:00,  6.41it/s]


Computed steering vectors for 33 layers
Vector shape: torch.Size([4096])


In [22]:
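def normalize_vectors(vectors):
    # Normalize in float32 and cast back: 1e-8 is below float16's smallest
    # positive value, so the epsilon offers no protection in half precision.
    return [
        (v.float() / (v.float().norm() + 1e-8)).to(v.dtype)
        for v in vectors
    ]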

steering_vectors_norm = normalize_vectors(steering_vectors)
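Because each vector now has unit norm, the steering edit `multiplier * v` has norm equal to the multiplier (2.0 below) at every layer. Whether that is a large or a small nudge depends on the norm of the residual stream it is added to. A hedged sketch for checking that ratio; `relative_steering_strength` is a hypothetical helper and the per-layer norms are made-up values for illustration only:

```python
import numpy as np

def relative_steering_strength(hidden_norms, multiplier: float, vec_norm: float = 1.0):
    """Ratio ||multiplier * v|| / ||h|| for each layer's hidden-state norm ||h||."""
    hidden_norms = np.asarray(hidden_norms, dtype=np.float64)
    return multiplier * vec_norm / hidden_norms

# Made-up per-layer residual-stream norms, for illustration only.
toy_norms = [2.0, 20.0, 80.0]
ratios = relative_steering_strength(toy_norms, multiplier=2.0)
print(ratios)
```

With the notebook's tensors, the real norms could come from something like `[h.float().norm().item() for h in pos_acts[1:]]`; a very small ratio would suggest the multiplier barely perturbs that layer.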


In [23]:
print("Normalized steering vectors ready.")
print("Num layers:", len(steering_vectors_norm))
print("Vector dim:", steering_vectors_norm[0].shape)


Normalized steering vectors ready.
Num layers: 33
Vector dim: torch.Size([4096])


In [33]:
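# 🔑 IMPORTANT: drop the embedding entry (index 0).
# hidden_states[0] is the embedding output, which no decoder layer produces;
# hidden_states[i + 1] is the output of model.model.layers[i], so after
# slicing, steering_vectors_tf[i] pairs with a hook on model.model.layers[i].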
steering_vectors_tf = steering_vectors_norm[1:]

# sanity check
assert len(steering_vectors_tf) == len(model.model.layers)

print("Using transformer-layer steering vectors only.")
print("Num transformer layers:", len(steering_vectors_tf))


Using transformer-layer steering vectors only.
Num transformer layers: 32


In [24]:
print("Model dtype:", model.model.layers[0].mlp.gate_proj.weight.dtype)
print("Steering vector dtype:", steering_vectors_norm[0].dtype)



Model dtype: torch.float16
Steering vector dtype: torch.float16


# Layer Sweep + Best-Layer Selection

Define the ps metric

ps = 1 if the context-substituted answer appears in the model output, else 0. Averaged over examples, it is the fraction of outputs that follow the context.

In [25]:
def compute_ps(output_text: str, substituted_answer: str) -> int:
    """
    Returns 1 if substituted (context-faithful) answer appears in output, else 0.
    """
    return int(substituted_answer.lower() in output_text.lower())
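One caveat when applying `compute_ps` to generation output: `model.generate` returns the prompt tokens followed by the new ones, so decoding the full sequence yields text that already contains the context, and the substituted answer is found there regardless of what the model says. A minimal sketch that scores only the generated continuation and averages over examples; `ps_on_continuation` and the example strings are hypothetical:

```python
def compute_ps(output_text: str, substituted_answer: str) -> int:
    # Same definition as above: 1 if the answer string occurs in the text.
    return int(substituted_answer.lower() in output_text.lower())

def ps_on_continuation(full_text: str, prompt_text: str, substituted_answer: str) -> int:
    """Score only the text generated after the prompt."""
    if full_text.startswith(prompt_text):
        full_text = full_text[len(prompt_text):]
    return compute_ps(full_text, substituted_answer)

prompt = "Context:\nKyle Korver is a center.\n\nQuestion:\nWhat position does Kyle Korver play?"
outputs = [
    prompt + " Kyle Korver plays center.",
    prompt + " Kyle Korver is a shooting guard.",
]
naive = [compute_ps(o, "center") for o in outputs]               # prompt leaks the answer
fair = [ps_on_continuation(o, prompt, "center") for o in outputs]
print(naive, fair, sum(fair) / len(fair))
```

String slicing is fragile when decoding drops special tokens; slicing the token ids instead, e.g. `output_ids[0, inputs["input_ids"].shape[1]:]`, before decoding is more robust.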


In [26]:
def generate_answer(prompt: str, max_new_tokens: int = 50):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; temperature is not used
            use_cache=False,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    new_tokens = output_ids[0, inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)


Activation steering hook (single layer)

In [27]:
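class SteeringHook:
    def __init__(self, steering_vector, multiplier):
        self.v = steering_vector
        self.m = multiplier

    def __call__(self, module, input, output):
        # Depending on the transformers version, a decoder layer returns either
        # the hidden-state tensor [batch, seq_len, hidden_dim] or a tuple
        # whose first element is that tensor.
        hidden = output[0] if isinstance(output, tuple) else output
        hidden[:, -1, :] += self.m * self.v.to(hidden.dtype)  # last token only
        return output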


Apply steering at a specific layer

In [28]:
def forward_with_steering(prompt, layer_idx, steering_vectors, multiplier=2.0):
    hook = SteeringHook(steering_vectors[layer_idx], multiplier)
    handle = model.model.layers[layer_idx].register_forward_hook(hook)

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits[:, -1, :]  # last-token logits
    finally:
        handle.remove()

    return logits
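`forward_with_steering` (and the generation function later on) repeats the register / try / finally / remove pattern. One way to factor it out is a small context manager. The sketch assumes only that the module exposes `register_forward_hook` returning a handle with `.remove()`, which is how PyTorch modules behave; `steering_applied` and the dummy classes used to exercise it are hypothetical:

```python
from contextlib import contextmanager

@contextmanager
def steering_applied(module, hook):
    """Register `hook` as a forward hook on `module` for the duration of a with-block."""
    handle = module.register_forward_hook(hook)
    try:
        yield handle
    finally:
        handle.remove()  # runs even if the forward pass raises

# Minimal stand-ins so the pattern can be exercised without a model.
class _Handle:
    def __init__(self, registry, hook):
        self.registry, self.hook = registry, hook

    def remove(self):
        self.registry.remove(self.hook)

class _DummyModule:
    def __init__(self):
        self.hooks = []

    def register_forward_hook(self, hook):
        self.hooks.append(hook)
        return _Handle(self.hooks, hook)

layer = _DummyModule()
with steering_applied(layer, lambda m, i, o: o):
    inside = len(layer.hooks)   # 1 while steering is active
after = len(layer.hooks)        # 0 once the block exits
```

With the notebook's objects this would read `with steering_applied(model.model.layers[layer_idx], SteeringHook(v, m)): outputs = model(**inputs)`.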


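# Layer sweep loop

For each transformer layer:

- apply steering at that layer
- run one forward pass on the open-ended prompt and read the next-token logits (no full generation)
- score whether the substituted answer's first token is preferred (a single-token proxy for ps)
- average the scores across examples

Then pick the best layer.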

In [29]:
def layer_sweep(dataset, steering_vectors, substituted_answers, multiplier=2.0):
    num_layers = len(steering_vectors)
    layer_ps = []

    for layer in tqdm(range(num_layers), desc="Layer sweep"):
        ps_scores = []

        for ex, sub_ans in zip(dataset, substituted_answers):
            prompt = build_open_ended_prompt(ex["context"], ex["question"])
            logits = forward_with_steering(
                prompt,
                layer_idx=layer,
                steering_vectors=steering_vectors,
                multiplier=multiplier
            )

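            # Proxy for ps: is the substituted answer's first token the top-1
            # next-token prediction? (A raw logit > 0 says little, since
            # logits are unnormalized.)
            sub_ids = tokenizer(sub_ans, add_special_tokens=False).input_ids
            if len(sub_ids) > 0:
                ps_scores.append(
                    int(logits[0].argmax().item() == sub_ids[0])
                )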
            else:
                ps_scores.append(0)

        layer_ps.append(np.mean(ps_scores))

        torch.cuda.empty_cache()

    return layer_ps
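Raw logits are unnormalized, so a positive logit says little about whether a token is actually preferred: many vocabulary entries can have positive logits at once. Stricter single-token proxies are whether the answer's first token is the top-1 prediction, or its log-probability under a softmax. A minimal NumPy sketch with a made-up five-token vocabulary; `first_token_preference` is a hypothetical helper:

```python
import numpy as np

def first_token_preference(logits: np.ndarray, token_id: int):
    """Return (is_top1, log_prob) for `token_id` given a 1-D logit vector."""
    shifted = logits - logits.max()                      # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())  # log-softmax
    return int(np.argmax(logits) == token_id), float(log_probs[token_id])

logits = np.array([3.0, 1.5, 0.2, 4.0, -1.0])            # toy vocabulary of 5 tokens
is_top1, lp = first_token_preference(logits, token_id=1)
print(int(logits[1] > 0), is_top1, lp)                   # positive logit, yet not preferred
```

Any first-token proxy misses answers that appear later in the sentence (e.g. "Kyle Korver plays center."), so generation-based ps remains the more faithful measure.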


In [34]:
substituted_answers = [
    "Brian Niccol",
    "John Williams",
    "center",
]

layer_ps_scores = layer_sweep(
    steering_data,
    steering_vectors_tf,
    substituted_answers,
    multiplier=2.0
)



Layer sweep: 100%|██████████| 32/32 [00:09<00:00,  3.53it/s]


# Final Inference-Time Steering + Comparison

Select the best layer (from previous sweep)

In [35]:
best_layer = int(np.argmax(layer_ps_scores))

print("Best transformer layer index:", best_layer)
print("Best ps score:", layer_ps_scores[best_layer])


Best transformer layer index: 1
Best ps score: 1.0
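With only three examples, each layer's score can take just four values (0, 1/3, 2/3, 1), so several layers may tie at the maximum, and `np.argmax` silently returns the first of them. A small sketch for listing every tied layer before committing to one; the scores are made up:

```python
import numpy as np

layer_scores = np.array([0.67, 1.0, 1.0, 0.33, 1.0])   # made-up per-layer ps values
best = int(np.argmax(layer_scores))                     # first maximum only
tied = np.flatnonzero(np.isclose(layer_scores, layer_scores.max())).tolist()
print(best, tied)
```

If `tied` has more than one entry, the choice among those layers is not determined by the sweep itself.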


Final, paper-faithful steering hook (last-token only; same logic as `SteeringHook` above)

In [36]:
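class FinalSteeringHook:
    def __init__(self, steering_vector, multiplier):
        self.v = steering_vector
        self.m = multiplier

    def __call__(self, module, input, output):
        # output is [batch, seq_len, hidden_dim], or a tuple whose first
        # element is that tensor, depending on the transformers version.
        hidden = output[0] if isinstance(output, tuple) else output
        hidden[:, -1, :] += self.m * self.v.to(hidden.dtype)  # last token only
        return output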


Final steered generation function

In [37]:
def generate_with_final_steering(
    prompt: str,
    layer_idx: int,
    steering_vectors,
    multiplier: float = 2.0,
    max_new_tokens: int = 50,
):
    hook = FinalSteeringHook(steering_vectors[layer_idx], multiplier)
    handle = model.model.layers[layer_idx].register_forward_hook(hook)

    try:
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # greedy decoding; temperature is not used
                use_cache=False,  # important with hooks
            )
        output_text = tokenizer.decode(
            output_ids[0], skip_special_tokens=True
        )
    finally:
        handle.remove()

    return output_text


Baseline generation

In [38]:
def generate_baseline(prompt: str, max_new_tokens: int = 50):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding; temperature is not used
            use_cache=False,
        )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)


## Side-by-side comparison

In [39]:
for ex in steering_data:
    prompt = build_open_ended_prompt(ex["context"], ex["question"])

    base_out = generate_baseline(prompt)
    steered_out = generate_with_final_steering(
        prompt,
        layer_idx=best_layer,
        steering_vectors=steering_vectors_tf,
        multiplier=2.0,
    )

    print("=" * 80)
    print("QUESTION:", ex["question"])
    print("CONTEXT:", ex["context"])
    print("\nBASE OUTPUT:")
    print(base_out)
    print("\nSTEERED OUTPUT (ContextFocus):")
    print(steered_out)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


QUESTION: Who is the CEO of Starbucks?
CONTEXT: Brian Niccol is the CEO of Starbucks.

BASE OUTPUT:
[INST] <<SYS>>
You are a Contextual QA Assistant. Please answer the following question according to the given context. Restrict your response to one sentence.
<</SYS>>

Context:
Brian Niccol is the CEO of Starbucks.

Question:
Who is the CEO of Starbucks? [/INST]  Brian Niccol is the CEO of Starbucks.

STEERED OUTPUT (ContextFocus):
[INST] <<SYS>>
You are a Contextual QA Assistant. Please answer the following question according to the given context. Restrict your response to one sentence.
<</SYS>>

Context:
Brian Niccol is the CEO of Starbucks.

Question:
Who is the CEO of Starbucks? [/INST]  Brian Niccol is the CEO of Starbucks.
QUESTION: Who composed the score for Titanic?
CONTEXT: The Titanic movie score was composed by John Williams.

BASE OUTPUT:
[INST] <<SYS>>
You are a Contextual QA Assistant. Please answer the following question according to the given context. Restrict your respo