In [None]:
# Install the notebook's runtime dependencies with the %pip magic.
# NOTE(review): versions are not pinned here — consider pinning them for reproducible runs.
%pip install transformers accelerate torch sentencepiece --quiet


In [1]:
import os
import json
import math
from dataclasses import dataclass
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [2]:
# ---------------------------------------------------------------------------
# Run configuration: model id, I/O paths, decoding settings, device and seed.
# ---------------------------------------------------------------------------
model_name = "meta-llama/Llama-2-7b-chat-hf"

# I/O paths
queries_path = "../data/query_data.jsonl"
# Single source of truth for the output directory, so the directory that gets
# created and the directory in `out_path` cannot drift apart.
output_dir = "../outputs"
os.makedirs(output_dir, exist_ok=True)
# Last component of the model id, with ":" replaced so it is safe in a filename.
model_stub = model_name.split("/")[-1].replace(":", "_")
out_path = f"{output_dir}/{model_stub}_preds.jsonl"

# Decoding config (as requested)
gen_cfg = dict(
    max_new_tokens=512,
    temperature=0.2,
    do_sample=True,
    top_k=60,
    top_p=0.9,
    num_beams=1,
    repetition_penalty=1.8,
)

# Sampling is enabled above (do_sample=True), so seed torch's RNG to make a
# fresh "Restart & Run All" produce repeatable generations.
seed = 42
torch.manual_seed(seed)

# Safety defaults: half precision only when a CUDA device is available,
# full precision on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
print(f"Using device={device}, dtype={torch_dtype}, writing to {out_path}")


Using device=cpu, dtype=torch.float32, writing to ../outputs/Llama-2-7b-chat-hf_preds.jsonl


In [None]:
# Load the tokenizer for the configured checkpoint (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    use_fast=True,
    trust_remote_code=True,
)

# A pad token is needed for batching/padding. If the checkpoint does not define
# one, reuse the EOS token — provided an EOS token exists.
if tokenizer.pad_token_id is None:
    if tokenizer.eos_token_id is not None:
        tokenizer.pad_token = tokenizer.eos_token

# Keyword arguments for loading the weights, gathered in one place.
# device_map="auto" uses all available GPUs/CPU.
load_kwargs = dict(
    torch_dtype=torch_dtype,
    device_map="auto",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

# Put the model in evaluation mode; this notebook only generates.
model.eval()
print("Model & tokenizer loaded.")

In [None]:
# Load the prepared prompts (context + adversarial instruction).
# The file is read inside a `with` block so the handle is closed promptly, and
# blank lines (e.g. a trailing newline at EOF) are skipped rather than being
# handed to json.loads.
with open(queries_path, "r", encoding="utf-8") as fin:
    queries = [json.loads(line) for line in fin if line.strip()]
print(f"Loaded {len(queries)} prompts from {queries_path}")

written = 0
with open(out_path, "w", encoding="utf-8") as fout:
    for ex in tqdm(queries, desc=f"Generating with {model_stub}"):
        prompt = ex["prompt"]

        # Tokenize the full prompt. No truncation is applied (truncation=False),
        # so the prompt is passed to the model at its full length.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=False)

        # Move every input tensor to the device the model reports.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        # Number of prompt tokens; used below to separate prompt from continuation.
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            # Pass the whole decoding config through, so a setting added to
            # `gen_cfg` can never be silently dropped here.
            outputs = model.generate(
                **inputs,
                **gen_cfg,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
            )

        # For decoder-only models, generate() returns the prompt tokens followed
        # by the newly generated ones. Keep only the new tokens so "output" holds
        # the model's answer rather than an echo of the prompt (which is already
        # stored separately under "prompt").
        new_tokens = outputs[0][prompt_len:]
        decoded = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Save one JSON line per example
        record = {
            "model": model_name,
            "anchor": ex["anchor"],
            "context": ex["context"],
            "prompt": prompt,
            "output": decoded,
        }
        fout.write(json.dumps(record, ensure_ascii=False) + "\n")
        written += 1

print(f"Saved {written} generations to {out_path}")

In [None]:
# Preview the first few saved generations as a quick sanity check.
preview_count = 2    # number of records to show
preview_chars = 800  # characters of each output to print

shown = 0
# Read through a context manager so the file handle is closed even though the
# loop exits early.
with open(out_path, "r", encoding="utf-8") as fin:
    for line in fin:
        obj = json.loads(line)
        print("———")
        print(obj["output"][:preview_chars])
        shown += 1
        if shown >= preview_count:
            break