In [1]:
import torch
from transformer_lens import HookedTransformer
from torch.utils.data import DataLoader, Dataset



In [2]:
def collate(batch):
    pos_prompts, answers, prompts = zip(*batch)
    return {"pos_prompts": list(pos_prompts), "prompts": list(prompts), "answers": list(answers)}

In [3]:
gpt = HookedTransformer.from_pretrained("gemma-2b", center_unembed=True, center_writing_weights=True, fold_ln=True)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [4]:
import json
import torch
from torch.utils.data import Dataset, DataLoader

# === Few-shot template ===
FEWSHOT_TEMPLATE = """Sentence: Apple announced a new iPhone during its annual product launch event.,
POS tag: PRODUCT,
Answer: iPhone

Sentence: Barack Obama delivered a keynote speech at the conference.,
POS tag: PERSON,
Answer: Barack Obama

Sentence: Tesla invested over 2 billion dollars in a new gigafactory in Germany.,
POS tag: MONEY,
Answer: 2 billion dollars

Sentence: The concert will take place at 8 p.m. on Saturday.,
POS tag: TIME,
Answer: 8 p.m. on Saturday

Sentence: The Eiffel Tower is located in Paris.,
POS tag: LOCATION,
Answer: Paris

Sentence: The Olympic Games in Tokyo attracted thousands of visitors despite the pandemic.,
POS tag: EVENT, 
Answer: Olympic Games

Sentence: The recipe calls for 200 grams of sugar and 3 eggs.,
POS tag: NUMERICAL, 
Answer: 200 grams

Sentence: Google has opened a new research center in Zurich to focus on AI development.,
POS tag: ORGANIZATION, 
Answer: Google

Sentence: The American have a long history of culinary excellence.,
POS tag: NATIONALITY, RELIGIOUS, or POLITICAL GROUP, 
Answer: American

Sentence: The Islam religion has over a billion followers worldwide.,
POS tag: NATIONALITY, RELIGIOUS, or POLITICAL GROUP, 
Answer: Islam

Sentence: {sentence}
POS tag: {tag}
Answer:"""

# === Dataset ===
class NERDataset(Dataset):
    def __init__(self, json_path, template=FEWSHOT_TEMPLATE):
        with open(json_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.template = template

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data[idx]
        sentence = row["Sentence"]
        tag = row["POS tag"]
        answer = row["Answer"]
        prompt = self.template.format(sentence=sentence, tag=tag)
        # Include raw fields so we can write them out on success
        return {"prompt": prompt, "answer": str(answer), "sentence": sentence, "tag": tag}

# === collate ===
def collate(batch):
    prompts = [b["prompt"] for b in batch]
    answers = [b["answer"] for b in batch]
    sentences = [b["sentence"] for b in batch]
    tags = [b["tag"] for b in batch]
    return {"prompts": prompts, "answers": answers, "sentences": sentences, "tags": tags}

# === evaluation ===
@torch.no_grad()
def next_token_accuracy(model, dataset, batch_size=16, print_info=True, save_path="/Users/merve/Desktop/last_code_path_patch/notebooks/dataset/ner_dataset_15each.json"):
    """
    Computes next-token accuracy. Additionally, if `save_path` is provided,
    writes a JSON file containing only CORRECT items in the format:
      [{"Sentence": ..., "POS tag": ..., "Answer": <prediction>}, ...]
    where Answer is the model's predicted *first token*.
    """
    device = getattr(model, "device", "cuda:1")
    model.to(device).eval()

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate)

    total, correct = 0, 0
    correct_records = []

    def decode_token_ids(token_id):
        if hasattr(model, "tokenizer") and hasattr(model.tokenizer, "decode"):
            return model.tokenizer.decode([int(token_id)])
        # Fallback for HookedTransformer
        return model.to_string(torch.tensor([[int(token_id)]], device=device))

    for batch in loader:
        prompts, answers = batch["prompts"], batch["answers"]
        sentences, tags = batch["sentences"], batch["tags"]

        # tokenize each prompt
        prompt_tok_list = [model.to_tokens(p, prepend_bos=True) for p in prompts]
        prompt_lens = [t.size(1) for t in prompt_tok_list]
        pad_id = (
            model.tokenizer.pad_token_id
            if hasattr(model, "tokenizer") and hasattr(model.tokenizer, "pad_token_id") and model.tokenizer.pad_token_id is not None
            else (model.tokenizer.eot if hasattr(model, "tokenizer") and hasattr(model.tokenizer, "eot") else 0)
        )
        toks = torch.nn.utils.rnn.pad_sequence(
            [t.squeeze(0) for t in prompt_tok_list],
            batch_first=True,
            padding_value=pad_id
        ).to(device)

        # forward
        logits = model(toks)  # [B, T_max, V]
        last_idx = torch.tensor([L - 1 for L in prompt_lens], device=toks.device)  # [B]
        next_logits = logits.gather(1, last_idx.view(-1,1,1).expand(-1,1,logits.size(-1))).squeeze(1)  # [B, V]
        pred_ids = next_logits.argmax(dim=-1)  # [B]

        # gold ids
        gold_ids = []
        for p, a, L in zip(prompts, answers, prompt_lens):
            both = model.to_tokens(p + " " + a, prepend_bos=True)
            gold_ids.append(int(both[0, L].item()))
        gold_ids = torch.tensor(gold_ids, device=pred_ids.device)

        # decode for inspection
        pred_texts = [decode_token_ids(t).strip() for t in pred_ids.tolist()]
        gold_texts = [decode_token_ids(t).strip() for t in gold_ids.tolist()]

        # tally
        matches = (pred_ids == gold_ids)
        correct += matches.sum().item()
        total += pred_ids.numel()

        # collect correct items for JSON dump
        for i, is_ok in enumerate(matches.tolist()):
            if is_ok:
                # Write the *prediction* as the Answer
                correct_records.append({
                    "Sentence": sentences[i],
                    "POS tag": tags[i],
                    "Answer": pred_texts[i]  # predicted first token
                })

        if print_info:
            TOPK = 5
            topk_vals, topk_ids = next_logits.topk(TOPK, dim=-1)  # [B, K]
            topk_texts = [[decode_token_ids(token_id).strip() for token_id in row.tolist()] for row in topk_ids]
            print(len(prompts))
            for i in range(len(prompts)):
                print(i)
                print("—" * 48)
                print(f"Prompt: ", prompts[i])
                print(f"Gold answer: {answers[i]}")
                print(f"Gold first token: {gold_texts[i]!r}")
                print(f"Pred first token: {pred_texts[i]!r}")
                print(f"Top-{TOPK}: {topk_texts[i]}")

    acc = 0.0 if total == 0 else 100.0 * correct / total
    print(f"\n✅ Next-token accuracy: {correct}/{total} = {acc:.2f}%")

    torch.cuda.empty_cache()
    del toks, logits

    # Write JSON of correct predictions if requested
    # if save_path is not None:
    #     with open(save_path, "w", encoding="utf-8") as f:
    #         json.dump(correct_records, f, ensure_ascii=False, indent=2)
    #     print(f"💾 Wrote {len(correct_records)} correct items to {save_path}")

    return acc

In [None]:
JSON_PATH = "../../pos_cf_datasets/ner_dataset_15each.json"
ds = NERDataset(JSON_PATH)
next_token_accuracy(gpt, ds, batch_size=16)

Moving model to device:  cuda:1
16
0
————————————————————————————————————————————————
Prompt:  Sentence: Apple announced a new iPhone during its annual product launch event.,
POS tag: PRODUCT,
Answer: iPhone

Sentence: Barack Obama delivered a keynote speech at the conference.,
POS tag: PERSON,
Answer: Barack Obama

Sentence: Tesla invested over 2 billion dollars in a new gigafactory in Germany.,
POS tag: MONEY,
Answer: 2 billion dollars

Sentence: The concert will take place at 8 p.m. on Saturday.,
POS tag: TIME,
Answer: 8 p.m. on Saturday

Sentence: The Eiffel Tower is located in Paris.,
POS tag: LOCATION,
Answer: Paris

Sentence: The Olympic Games in Tokyo attracted thousands of visitors despite the pandemic.,
POS tag: EVENT, 
Answer: Olympic Games

Sentence: The recipe calls for 200 grams of sugar and 3 eggs.,
POS tag: NUMERICAL, 
Answer: 200 grams

Sentence: Google has opened a new research center in Zurich to focus on AI development.,
POS tag: ORGANIZATION, 
Answer: Google

Sente

52.5

: 