You can try your own prompt with a different model here.

In [None]:
from pathlib import Path
import os
import numpy as np

# Model to load: a HuggingFace model name or a local path
model_name = "meta-llama/Llama-2-7b-hf"

# JSON file holding the hidden-state statistics (must contain obj["summary"])
stats_json_path = Path("flores_hidden_stats_llama2_7b_50.json")  # change if needed

# Where and in which precision to run
device = "cuda"     # "cuda" or "cpu"
dtype = "float16"   # "float16" | "bfloat16" | "float32"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Echo the effective configuration
print(f"model_name: {model_name}")
print(f"stats_json_path: {stats_json_path.resolve()}")
print(f"device: {device} | dtype: {dtype}")

In [None]:
# Loading model and wrapping
import json
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from repeng import ControlVector, ControlModel
from repeng.control import ControlModule

def resolve_dtype(dtype: str) -> torch.dtype:
    """Translate a dtype name into the corresponding torch dtype.

    "float16" and "bfloat16" map to their torch counterparts; every other
    value resolves to torch.float32.
    """
    reduced_precision = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return reduced_precision.get(dtype, torch.float32)

def load_summary(stats_json_path: Path):
    """Return the "summary" entry of the UTF-8 JSON file at *stats_json_path*."""
    payload = json.loads(stats_json_path.read_text(encoding="utf-8"))
    return payload["summary"]

def patch_control_forward_overwrite_nonzero():
    """
    Overwrite ControlModule.forward with overwrite-nonzero behavior.

    After this call, every ControlModule runs its wrapped block and then, on
    each hidden dimension where the scaled control is non-zero, replaces the
    block output with ``control * ControlModule.custom_coeff``. Dimensions
    where the control is zero keep the block output unchanged.

    ``ControlModule.custom_coeff`` is a class-level scalar multiplier; it is
    (re)set to 1.0 every time this function is called.
    """
    ControlModule.custom_coeff = 1.0

    def normalized_forward(self, *args, **kwargs):
        output = self.block(*args, **kwargs)
        control = self.params.control

        # Nothing to apply: no control set, or an all-zero control.
        if control is None:
            return output
        if control.ndim == 1:
            control = control.reshape(1, 1, -1)
        if torch.all(control == 0):
            return output

        modified = output[0] if isinstance(output, tuple) else output
        # Align device AND dtype with the hidden state: a float32 control
        # against fp16/bf16 activations would otherwise make torch.where
        # promote (or reject) the result and break the following layers.
        control = control.to(device=modified.device, dtype=modified.dtype)

        norm_pre = torch.norm(modified, dim=-1, keepdim=True)

        # Padding mask handling (optional). `position_ids` may be absent or
        # explicitly None; both mean "no mask".
        pos = kwargs.get("position_ids")
        if pos is not None:
            # Index of the last column whose position id is 0 (column 0 when
            # no position id is 0). The mask is 1 from that column onward and
            # 0 before it — presumably skipping left-padding; TODO confirm.
            zero_idx = (pos == 0).cumsum(1).argmax(1, keepdim=True)
            col = torch.arange(pos.size(1), device=pos.device).unsqueeze(0)
            mask = (col >= zero_idx).float().reshape(modified.shape[0], modified.shape[1], 1)
            mask = mask.to(device=modified.device, dtype=modified.dtype)
        else:
            mask = 1.0

        # Masked-out positions end up with an all-zero control and therefore
        # keep the block output in the torch.where below.
        control_applied = control * float(ControlModule.custom_coeff) * mask
        modified = torch.where(control_applied != 0, control_applied, modified)

        if self.params.normalize:
            # Rescale each position back to its pre-intervention norm.
            norm_post = torch.norm(modified, dim=-1, keepdim=True)
            modified = modified / norm_post * norm_pre

        if isinstance(output, tuple):
            return (modified,) + output[1:]
        return modified

    ControlModule.forward = normalized_forward

# ---- Load summary ----
# `summary` is the "summary" entry of the stats JSON; it is indexed later as
# summary[lang]["layer_<i>"]["mean"].
summary = load_summary(stats_json_path)

# ---- Load tokenizer/model ----
# NOTE(review): trust_remote_code=True lets the checkpoint repository execute
# its own Python code while loading — only point model_name at trusted sources.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    # Automatic device placement on CUDA; no device map otherwise.
    device_map="auto" if device.startswith("cuda") else None,
    torch_dtype=resolve_dtype(dtype),
)
# Inference only: switch to eval mode.
base_model.eval()

# Fall back to EOS as the pad token when the tokenizer defines none.
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
    tokenizer.pad_token_id = tokenizer.eos_token_id

# ---- Patch repeng forward ----
# Replaces ControlModule.forward at class level with the overwrite-nonzero variant.
patch_control_forward_overwrite_nonzero()

# ---- Wrap model with ControlModel ----
# Layer ids here are 0-based indices into base_model.model.layers.
num_layers = len(base_model.model.layers)
last_layer_id = num_layers - 1

# Only layers wrap_start_layer .. num_layers-1 are handed to ControlModel;
# layers with a smaller index are left unwrapped.
wrap_start_layer = 6  # you can keep this fixed for playground
wrapped_layers = list(range(wrap_start_layer, num_layers))
model = ControlModel(base_model, wrapped_layers)

print("Loaded model:", model_name)
print("num_layers:", num_layers, "| last_layer_id:", last_layer_id)
print("wrapped layers:", wrapped_layers[0], "...", wrapped_layers[-1])

In [None]:
# Playground settings: edit the values below and re-run
# Setting: "para" or "mono"
setting = "para"
anchor_layer = 19  # 1-based; only used when setting == "mono"


target_lang = "zh"         # target language for EN->X
k_dims = 400               # top-K dims

# Layer to intervene on (1-based) and the strengths to sweep over
intervention_layer_1based = 19
strength_list = [0.4, 0.8, 1.2]

# Decoding params
max_new_tokens = 64

do_sample = False
temperature = 0.0
top_p = 0.9
repetition_penalty = 1.1

# Prompt
prompt = (
    "Translate an English sentence into target language.\n"
    "English: I love Kyoto in winter.\n"
    "Target language: "
)

# Echo the effective settings
print(f"setting: {setting} | target_lang: {target_lang} | k_dims: {k_dims}")
print(f"intervention_layer_1based: {intervention_layer_1based} | strengths: {strength_list}")
print(f"do_sample: {do_sample} | temperature: {temperature} | top_p: {top_p} | max_new_tokens: {max_new_tokens}")
print(f"prompt: {prompt}")

In [None]:
import numpy as np
import torch
from repeng import ControlVector
from repeng.control import ControlModule

# Convert the 1-based `anchor_layer` setting into a 0-based layer index.
# NOTE(review): this mutates the global in place, so re-running this cell
# without first re-running the settings cell decrements it a second time.
anchor_layer-=1
def apply_model_specific_zeroing(diff: np.ndarray, model_name: str):
    """Zero hard-coded entries of *diff* in place for known model families.

    Only Llama-2 7B and 13B are listed here; add further entries if another
    model needs the same special handling. An index is skipped when it lies
    beyond the first axis of *diff*. Nothing is returned.
    """
    # model-name fragment -> indices along axis 0 to zero out
    zeroed_by_model = (
        ("Llama-2-7b", (1415, 2533)),
        ("Llama-2-13b", (2100,)),
    )
    size = diff.shape[0]
    for fragment, indices in zeroed_by_model:
        if fragment not in model_name:
            continue
        for idx in indices:
            if idx < size:
                diff[idx] = 0.0

def build_control_vec_en2x(summary, setting, lang, k_dims, last_layer_id, anchor_layer):
    """
    Build the EN->X control vector from per-layer mean statistics.

    Source/target pairing:
      para: source = "en" at the last layer,    target = lang at the last layer
      mono: source = lang at the anchor layer,  target = lang at the last layer

    The k_dims dimensions with the largest |source - target| are selected
    (after model-specific zeroing of the difference); the returned float32
    vector holds the target mean on those dimensions and 0 everywhere else.
    """
    def _layer_mean(language, layer_id):
        # Mean vector stored under summary[language]["layer_<id>"]["mean"]
        return np.asarray(summary[language][f"layer_{layer_id}"]["mean"], dtype=np.float32)

    target_mean = _layer_mean(lang, last_layer_id)
    if setting == "mono":
        source_mean = _layer_mean(lang, anchor_layer)
    else:
        source_mean = _layer_mean("en", last_layer_id)

    diff = source_mean - target_mean
    # In-place; reads the notebook-level `model_name`.
    apply_model_specific_zeroing(diff, model_name)

    # Rank dimensions by descending absolute difference and keep the first k_dims.
    ranked = np.argsort(-np.abs(diff))
    selected = ranked[:k_dims]

    control_vec = np.zeros_like(target_mean, dtype=np.float32)
    control_vec[selected] = target_mean[selected]
    return control_vec

def truncate_at_stop(s: str) -> str:
    """Return *s* up to (not including) its first newline, or all of *s* if it has none."""
    head, _, _ = s.partition("\n")
    return head

# ----- Validate layer (1-based -> 0-based) -----
layer_idx = int(intervention_layer_1based) - 1
if not (0 <= layer_idx < num_layers):
    raise ValueError(f"intervention_layer_1based must be in [1, {num_layers}], got {intervention_layer_1based}")
# ControlModel was built with `wrapped_layers` only. A direction for any other
# layer would presumably never be applied (NOTE(review): confirm against the
# installed repeng version), so the run would be reported as an intervention
# while being unsteered. Fail loudly instead.
if layer_idx not in wrapped_layers:
    raise ValueError(
        f"intervention_layer_1based must be a wrapped layer in "
        f"[{wrapped_layers[0] + 1}, {wrapped_layers[-1] + 1}], got {intervention_layer_1based}"
    )

# ----- Build control vector for target_lang -----
control_vec = build_control_vec_en2x(
    summary=summary,
    setting=setting,
    lang=target_lang,
    k_dims=k_dims,
    last_layer_id=last_layer_id,
    anchor_layer=anchor_layer,
)

# ----- Decode kwargs (avoid warnings if greedy) -----
gen_kwargs = dict(
    do_sample=do_sample,
    repetition_penalty=float(repetition_penalty),
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=int(max_new_tokens),
)
# Sampling-only parameters are passed only when sampling is enabled.
if do_sample:
    gen_kwargs["temperature"] = float(temperature)
    gen_kwargs["top_p"] = float(top_p)

# ----- Prepare prompt -----
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}

print(f"Prompt:\n{prompt}\n")

# ----- Control at exactly one layer (identical for every strength) -----
# Built once: the strength is not baked into the vector, it is applied by the
# patched forward through ControlModule.custom_coeff.
layer_directions = {
    lid: (control_vec if lid == layer_idx else np.zeros_like(control_vec))
    for lid in range(num_layers)
}
control_vector = ControlVector(model_type="llama", directions=layer_directions)

# ----- Run for each strength -----
for strength in strength_list:
    model.reset()
    model.set_control(control_vector, coeff=1.0)

    ControlModule.custom_coeff = float(strength)

    out = model.generate(**inputs, **gen_kwargs)
    # Drop the prompt tokens; keep only the newly generated continuation.
    gen_tokens = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    text = truncate_at_stop(text).strip()

    print(f"Intervention layer: {intervention_layer_1based} | Strength: {float(strength):.2f}")
    print(f"Response: {text}\n")

    if device.startswith("cuda"):
        torch.cuda.empty_cache()