# Zero-Shot Citation Hallucination Detection
Full pipeline: **Generate → Verify → Score → Visualize**

Adapts the FACTUM framework for zero-shot (non-RAG) citation hallucination detection.
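- **ICS** (Internal Consistency Score): cosine similarity between the citation token's hidden state and the mean hidden state of the essay tokens that precede it
- **POS** (Pathway Orthogonality Score): 1 - |cos(pre_layer, post_layer)| at the citation token, where pre/post are the hidden states entering and leaving a layer
- **PFS** (Pathway Fluctuation Score): L2 norm of the per-layer hidden-state update at the citation token (the code reports half of this norm)
- **BAS** (BOS Attention Score): attention from the citation token to BOS (position 0), averaged over heads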

> **Note:** FACTUM-main code is NOT used at runtime — all scoring is self-contained here.

---
## 1. Setup

In [None]:
!pip install -q torch transformers accelerate requests tqdm
import torch, gc
# Fail fast when no CUDA device is attached: later cells move inputs to 'cuda'.
assert torch.cuda.is_available(), 'Switch to GPU runtime!'
print(f"GPU: {torch.cuda.get_device_name(0)}")
# The device-properties attribute is `total_memory` (in bytes); `total_mem` does not
# exist and raised AttributeError.
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
from huggingface_hub import login
# Read the Hugging Face token from the HF_TOKEN environment variable instead of
# hard-coding it in the notebook. When the variable is unset or empty, token=None
# makes login() ask for the token interactively.
import os
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# Hugging Face model id; loaded once for generation and again for scoring.
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
# How many entries of PROMPTS are actually generated (PROMPTS[:MAX_PROMPTS]).
MAX_PROMPTS = 5
# Generation budget per essay, passed to generate() as max_new_tokens.
MAX_NEW_TOKENS = 1024
# Sampling temperature passed to generate().
TEMPERATURE = 0.7
SCORE_MAX_LEN = 512  # Max tokens for scoring (fits T4 VRAM)

---
## 2. Prompts

In [None]:
# Essay prompts fed to the generator. Each entry carries an id, a domain tag and the
# user instruction; every prompt explicitly asks for citations or references.
PROMPTS = [
    {"id": "sci_001", "domain": "scientific", "prompt": "Write a 500-word essay explaining the Transformer architecture and its impact on natural language processing. Include at least 3 citations to foundational papers with author names and publication years."},
    {"id": "sci_002", "domain": "scientific", "prompt": "Discuss the development and applications of CRISPR-Cas9 gene editing technology. Cite at least 3 seminal papers that contributed to this field."},
    {"id": "sci_003", "domain": "scientific", "prompt": "Explain how deep reinforcement learning achieved superhuman performance in games like Go and Chess. Reference the key papers with proper citations."},
    {"id": "sci_004", "domain": "scientific", "prompt": "Describe the methodology and significance of the AlphaFold protein structure prediction system. Include citations to the original research papers."},
    {"id": "sci_005", "domain": "scientific", "prompt": "Write about the discovery and implications of the Higgs boson. Cite the foundational theoretical papers and experimental confirmation studies."},
]
print(f"{len(PROMPTS)} prompts loaded")

---
## 3. Citation Extractor

In [None]:
import re
from dataclasses import dataclass, asdict, field
from typing import List, Optional

@dataclass
class Citation:
    """One citation-like span found in generated text by extract_citations."""
    # Exact matched substring.
    raw_text: str
    # Character offsets of the match within the searched text, taken from
    # re.Match.span() (so end_pos is exclusive).
    start_pos: int
    end_pos: int
    # Author surname(s) parsed from the match; verify_citation only uses the first entry.
    extracted_authors: Optional[List[str]] = None
    # Four-digit publication year parsed from the match.
    extracted_year: Optional[int] = None
    # Title text when available; verify_citation adds its first 50 characters to the query.
    extracted_title: Optional[str] = None
    # Free-form category tag; nothing in this file sets it to anything but the default.
    citation_type: str = "academic"

# Regexes for the citation styles the generator is asked to produce. extract_citations
# tries them in this order and keeps the first match that claims a span.
CITATION_PATTERNS = [
    # Parenthetical: (Smith, 2020), (Smith and Jones, 2020a), (Devlin et al., 2019).
    # The optional trailing "et al." group is required for "(Author et al., Year)":
    # the original alternation only accepted "et al." when another capitalised name
    # followed it, so that format never matched. Everything the old pattern matched
    # still matches.
    re.compile(r'\(([A-Z][a-z]+(?:\s+(?:et\s+al\.?|&|and)\s+[A-Z][a-z]+)*(?:\s+et\s+al\.?)?,?\s*\d{4}[a-z]?)\)'),
    # Narrative with "et al.": Vaswani et al. (2017)
    re.compile(r'([A-Z][a-z]+(?:\s+et\s+al\.?))\s*\((\d{4}[a-z]?)\)'),
    # Narrative with one or two authors: Smith (2020), Smith and Jones (2020)
    re.compile(r'([A-Z][a-z]+(?:\s+(?:and|&)\s+[A-Z][a-z]+)?)\s*\((\d{4})\)'),
    # Quoted title (10-100 characters) followed by a year: "Some Paper Title" (2017)
    re.compile(r'"([^"]{10,100})"\s*\((\d{4})\)'),
]

def extract_citations(text):
    """Find citation-like spans in ``text`` and return them as ``Citation`` objects.

    Every regex in ``CITATION_PATTERNS`` is tried in order. A match is kept only if
    it does not overlap a span accepted earlier, so earlier patterns win.

    Args:
        text: Text to scan (the generated essay).

    Returns:
        List of ``Citation`` sorted by ``start_pos``. For a quoted-title match
        (one that starts with a double quote) ``extracted_title`` and
        ``extracted_year`` are filled and ``extracted_authors`` stays ``None``.
        For every other match ``extracted_year`` is the first four-digit number
        and ``extracted_authors`` holds the first capitalised word.
    """
    citations = []
    claimed = []  # half-open (start, end) spans that have already been accepted
    for pat in CITATION_PATTERNS:
        for m in pat.finditer(text):
            s, e = m.span()
            # Half-open interval test: two spans that merely touch do not overlap.
            # (The previous per-character set with an inclusive end dropped a
            # citation that started exactly where another one ended.)
            if any(s < ce and cs < e for cs, ce in claimed):
                continue
            claimed.append((s, e))

            raw = m.group(0)
            c = Citation(raw_text=raw, start_pos=s, end_pos=e)
            if raw.startswith('"'):
                # Quoted-title form: "<title>" (<year>). Take the year from the part
                # after the closing quote so digits inside the title are ignored, and
                # do not invent an author from the title's first capitalised word.
                title, _, tail = raw[1:].rpartition('"')
                c.extracted_title = title
                yr = re.search(r'\d{4}', tail)
            else:
                yr = re.search(r'\d{4}', raw)
                au = re.search(r'([A-Z][a-z]+)', raw)
                if au:
                    c.extracted_authors = [au.group(1)]
            if yr:
                c.extracted_year = int(yr.group())
            citations.append(c)

    # Report citations in reading order rather than grouped by pattern.
    citations.sort(key=lambda c: c.start_pos)
    return citations

# Quick test
# Smoke-check the extractor on a sentence that contains one narrative citation and
# one parenthetical citation, and print how many were found and their raw text.
test = extract_citations('Vaswani et al. (2017) introduced Transformers. See also (Devlin et al., 2019).')
print(f'Test: {len(test)} citations — {", ".join(c.raw_text for c in test)}')

---
## 4. Generate Essays

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm
import json

# Load the tokenizer and the generation model in half precision; device placement
# is delegated to the library via device_map='auto'.
print(f'Loading {MODEL_NAME}...')
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float16, device_map='auto')
# Fall back to EOS as the pad token when the tokenizer defines none; generate()
# below is given pad_token_id explicitly.
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
print(f'Loaded. {model.config.num_hidden_layers}L, {model.config.num_attention_heads}H')

In [None]:
# System message that steers the model towards the citation formats the regex
# extractor looks for.
SYSTEM_PROMPT = """You are an academic expert writing for a peer-reviewed journal.
You MUST include proper academic citations in the format: Author (Year) or (Author et al., Year).
Every factual claim must be supported by a citation. Use real author names and years.
Example: Vaswani et al. (2017) introduced the Transformer architecture."""

# One record per generated essay: prompt metadata, the response text and the
# citations extracted from it (stored as plain dicts so the list can be dumped to JSON).
essays = []
for p in tqdm(PROMPTS[:MAX_PROMPTS], desc='Generating'):
    messages = [{"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": p["prompt"]}]
    # Prefer the tokenizer's chat template; if applying it raises, fall back to a
    # plain-text prompt layout.
    try:
        formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        formatted = f"{SYSTEM_PROMPT}\n\nRequest: {p['prompt']}\n\nResponse:"

    inputs = tokenizer(formatted, return_tensors='pt', truncation=True, max_length=2048).to('cuda')
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, temperature=TEMPERATURE,
                             top_p=0.9, do_sample=True, pad_token_id=tokenizer.pad_token_id)
    # Slice off the prompt tokens so only the newly generated continuation is decoded.
    response = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    # Citation offsets are relative to `response` (the stripped continuation).
    cites = extract_citations(response)
    essays.append({'prompt_id': p['id'], 'domain': p['domain'], 'prompt': p['prompt'],
                   'model_name': MODEL_NAME, 'response': response, 'citations': [asdict(c) for c in cites]})
    print(f"  {p['id']}: {len(response)} chars, {len(cites)} citations")

total_c = sum(len(e['citations']) for e in essays)
print(f'\nDone: {len(essays)} essays, {total_c} total citations')

In [None]:
# Preview
# For each essay print a header with its citation count, the first 500 characters
# of the response, and the raw text of up to three extracted citations.
for e in essays:
    print(f"\n=== {e['prompt_id']} ({len(e['citations'])} cites) ===")
    print(e['response'][:500])
    for c in e['citations'][:3]:
        print(f"  → {c['raw_text']}")

In [None]:
# Save essays & free generation model
import os
# Persist the generated essays before the model is released.
os.makedirs('results', exist_ok=True)
with open('results/essays.json', 'w') as f: json.dump(essays, f, indent=2)
print(f'Saved {len(essays)} essays to results/essays.json')

# Drop the generation model (the tokenizer is kept; the scoring cells reuse it),
# then collect garbage and release cached CUDA memory before reporting what is
# still allocated.
del model
gc.collect(); torch.cuda.empty_cache()
print(f'VRAM after cleanup: {torch.cuda.memory_allocated()/1e9:.2f} GB')

---
## 5. Citation Verification (Semantic Scholar)

In [None]:
import time, requests

def _paper_matches(paper, authors, year):
    """Return True when a search hit agrees with the cited year and first-author surname."""
    if year is not None and paper.get('year') != year:
        return False
    if not authors:
        return True
    name_parts = authors[0].split()
    if not name_parts:
        # A blank author string carries no constraint (and used to raise IndexError).
        return True
    surname = name_parts[-1].lower()
    # `authors` and `name` may be null in the API response; treat them as empty.
    return any(surname in (a.get('name') or '').lower() for a in (paper.get('authors') or []))


def verify_citation(authors, year, title=None):
    """Look a citation up on Semantic Scholar and classify it.

    The query is built from the first author, the year and the first 50
    characters of the title, using whichever of them are present.

    Args:
        authors: List of author names, or None. Only the first entry is used.
        year: Publication year as an int, or None.
        title: Optional title text.

    Returns:
        A ``(label, confidence, matched_title)`` tuple:
        ``('real', 0.85, title)`` when one of the top five hits matches the year
        and the first author's surname; ``('fabricated', 0.7, None)`` when the
        search returns no hits; ``('fabricated', 0.6, None)`` when hits exist but
        none matches; ``('unverified', 0.0, None)`` when there is nothing to
        search for, the request fails, the status is not 200, or the body is
        not a JSON object.
    """
    parts = []
    if authors: parts.append(authors[0])
    if year: parts.append(str(year))
    if title: parts.append(title[:50])
    if not parts:
        return 'unverified', 0.0, None

    # Only the network call and JSON decoding are guarded, so a programming error
    # in the matching logic is no longer silently reported as "unverified".
    try:
        r = requests.get('https://api.semanticscholar.org/graph/v1/paper/search',
                         params={'query': ' '.join(parts), 'limit': 5, 'fields': 'title,authors,year'}, timeout=10)
    except requests.RequestException:
        return 'unverified', 0.0, None
    time.sleep(1)  # crude pacing between consecutive API calls
    if r.status_code != 200:
        return 'unverified', 0.0, None
    try:
        payload = r.json()
    except ValueError:
        return 'unverified', 0.0, None
    if not isinstance(payload, dict):
        return 'unverified', 0.0, None

    papers = payload.get('data') or []
    if not papers:
        return 'fabricated', 0.7, None
    for p in papers:
        if isinstance(p, dict) and _paper_matches(p, authors, year):
            return 'real', 0.85, p.get('title')
    return 'fabricated', 0.6, None

# Verify every extracted citation, annotate it in place with the verdict, and
# keep a per-label tally for the summary below.
stats = dict.fromkeys(('real', 'fabricated', 'unverified'), 0)
for essay in tqdm(essays, desc='Verifying'):
    for c in essay['citations']:
        label, conf, matched = verify_citation(
            c.get('extracted_authors'),
            c.get('extracted_year'),
            c.get('extracted_title'),
        )
        c.update(label=label, confidence=conf, matched_title=matched)
        stats[label] += 1

total = sum(stats.values())
print(f'\nVerification ({total} citations):')
for k, v in stats.items():
    # Percentages are only meaningful when at least one citation was checked.
    if total:
        print(f'  {k}: {v} ({v/total*100:.0f}%)')
    else:
        print(f'  {k}: {v}')

---
## 6. Mechanistic Scoring

**Strategy for T4 (15GB VRAM):**
- **Pass 1:** `output_hidden_states=True, output_attentions=False` → compute ICS, POS, PFS
- **Pass 2:** `output_attentions=True, output_hidden_states=False` with `max_length=256` → compute BAS only (so attention-based BAS is only measured for citations that start within the first 256 tokens)

This avoids the OOM from requesting both simultaneously.

In [None]:
import torch.nn.functional as F

# Reload the same checkpoint for scoring. Eager attention is requested, presumably
# because the BAS pass below asks for attention weights (output_attentions=True)
# — TODO confirm against the transformers version in use.
print(f'Loading {MODEL_NAME} for scoring (eager attention)...')
scorer_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16, device_map='auto', attn_implementation='eager',
)
scorer_model.eval()
# Layer and head counts are read from the model config; the scoring and plotting
# cells iterate over NUM_LAYERS.
NUM_LAYERS = scorer_model.config.num_hidden_layers
NUM_HEADS = scorer_model.config.num_attention_heads
print(f'Ready: {NUM_LAYERS} layers, {NUM_HEADS} heads, VRAM: {torch.cuda.memory_allocated()/1e9:.1f} GB')

In [None]:
def find_citation_token_positions(essay, tokenizer, max_len):
    """Map character-level citation positions to token positions.

    The prompt and the response are concatenated and tokenized once (truncated
    to ``max_len``); all positions are derived from that single encoding's
    offset mapping.

    Args:
        essay: Dict with 'prompt', 'response' and 'citations'. Each citation is
            a dict whose 'start_pos' is a character offset into the response.
        tokenizer: Callable that accepts ``return_offsets_mapping=True`` and
            returns an encoding exposing ``offset_mapping`` (and ``input_ids``
            for the caller).
        max_len: Truncation length in tokens.

    Returns:
        ``(enc, prompt_tok_len, results)`` where ``enc`` is the encoding,
        ``prompt_tok_len`` is the index of the first token that reaches into the
        response (or the sequence length if none does), and ``results`` is a
        list of ``(citation, token_index)`` pairs for the citations that fall
        inside the truncated sequence.
    """
    prompt_chars = len(essay['prompt'])
    full_text = essay['prompt'] + essay['response']
    enc = tokenizer(full_text, return_tensors='pt', return_offsets_mapping=True,
                    truncation=True, max_length=max_len)
    offsets = enc.offset_mapping[0].tolist()

    def first_token_ending_after(char_pos):
        # Tokens with an empty (0, 0) span can never satisfy te > char_pos, so
        # they are skipped. A character that falls in a gap between two token
        # spans maps to the next token instead of being dropped.
        return next((i for i, (_, te) in enumerate(offsets) if te > char_pos), None)

    # Measure the prompt on the same (truncated) encoding that is scored, instead
    # of a second, untruncated tokenization whose boundary may not line up with it.
    boundary = first_token_ending_after(prompt_chars)
    prompt_tok_len = len(offsets) if boundary is None else boundary

    results = []
    for cite in essay['citations']:
        tok_pos = first_token_ending_after(cite['start_pos'] + prompt_chars)
        # None means the citation starts beyond the truncated sequence.
        if tok_pos is not None:
            results.append((cite, tok_pos))
    return enc, prompt_tok_len, results

# Cell-completion marker so the notebook shows the helper above was defined.
print('Utility functions defined.')

In [None]:
# One dict per scored citation: per-layer ICS/POS/PFS/BAS lists plus summary values.
all_scores = []
# Pass 2 returns a (heads, seq, seq) attention map per layer, so it is run on a
# shorter window than the hidden-state pass to stay within GPU memory.
BAS_MAX_LEN = 256

for essay in tqdm(essays, desc='Scoring'):
    if not essay['citations']: continue

    # --- PASS 1: Hidden states (ICS, POS, PFS) ---
    enc, prompt_len, cite_positions = find_citation_token_positions(essay, tokenizer, SCORE_MAX_LEN)
    if not cite_positions: continue

    try:
        with torch.no_grad():
            out1 = scorer_model(input_ids=enc.input_ids.cuda(),
                                output_hidden_states=True, output_attentions=False, return_dict=True)
        # Move to the CPU and upcast to float32 before scoring: in half precision
        # values close to 1.0 are spaced roughly 5e-4 apart, which coarsely
        # quantises 1 - |cos|, and the context mean is taken over many tokens.
        hs = [h.cpu().float() for h in out1.hidden_states]
        del out1; torch.cuda.empty_cache()
    except Exception as e:
        print(f'  HS error {essay["prompt_id"]}: {e}'); continue

    # --- PASS 2: Attention (BAS) — shorter max_len to fit ---
    enc2, prompt_len2, cite_positions2 = find_citation_token_positions(essay, tokenizer, BAS_MAX_LEN)
    bas_per_cite = {}  # tok_pos -> bas_scores
    if cite_positions2:  # skip the forward pass when no citation lies in the window
        try:
            with torch.no_grad():
                out2 = scorer_model(input_ids=enc2.input_ids.cuda(),
                                    output_hidden_states=False, output_attentions=True, return_dict=True)
            # attentions is a tuple of (batch, heads, seq, seq) per layer
            for cite, tok_pos in cite_positions2:
                # FACTUM BAS: attention FROM citation token TO position 0 (BOS),
                # averaged across heads, one value per layer.
                bas_per_cite[tok_pos] = [
                    out2.attentions[layer_idx][0, :, tok_pos, 0].float().mean().item()
                    for layer_idx in range(NUM_LAYERS)
                ]
            del out2; torch.cuda.empty_cache()
        except Exception as e:
            print(f'  BAS error {essay["prompt_id"]}: {e}')

    # --- Combine scores per citation ---
    for cite, tok_pos in cite_positions:
        ics_scores, pos_scores, pfs_scores = [], [], []
        for layer in range(NUM_LAYERS):
            # hs[layer] is the state entering the layer, hs[layer + 1] the state leaving it.
            pre = hs[layer][0, tok_pos]
            post = hs[layer + 1][0, tok_pos]
            # POS: cosine orthogonality
            cos = F.cosine_similarity(pre.unsqueeze(0), post.unsqueeze(0)).item()
            pos_scores.append(1.0 - abs(cos))
            # PFS: magnitude of hidden state update
            # NOTE(review): the norm is halved here; the reason is not visible in this file.
            pfs_scores.append(torch.norm(post - pre).item() / 2)
            # ICS: cosine(citation_hs, mean_essay_hs) over the essay tokens before the citation
            hs_essay = hs[layer + 1][0, prompt_len:tok_pos]
            if hs_essay.shape[0] > 0:
                ics = F.cosine_similarity(post.unsqueeze(0), hs_essay.mean(0).unsqueeze(0)).item()
            else:
                ics = 0.0
            ics_scores.append(ics)

        # BAS is only available for citations inside the shorter pass-2 window (or
        # when pass 2 succeeded). Missing values are recorded as an empty list and a
        # NaN mean rather than zeros, so they cannot drag label averages towards 0.
        # The plotting cell skips lists that are too short, and pandas ignores NaN
        # in means.
        bas_scores = bas_per_cite.get(tok_pos, [])
        bas_mean = sum(bas_scores)/len(bas_scores) if bas_scores else float('nan')

        all_scores.append({
            'ics_scores': ics_scores, 'ics_mean': sum(ics_scores)/len(ics_scores), 'ics_final': ics_scores[-1],
            'pos_scores': pos_scores, 'pos_mean': sum(pos_scores)/len(pos_scores), 'pos_final': pos_scores[-1],
            'pfs_scores': pfs_scores, 'pfs_mean': sum(pfs_scores)/len(pfs_scores), 'pfs_final': pfs_scores[-1],
            'bas_scores': bas_scores, 'bas_mean': bas_mean,
            'prompt_id': essay['prompt_id'], 'label': cite.get('label', 'unverified'),
            'citation': cite['raw_text'], 'tok_pos': tok_pos,
            'token': tokenizer.decode(enc.input_ids[0, tok_pos]),
        })

    del hs; torch.cuda.empty_cache()

print(f'\nScored {len(all_scores)} citations')

---
## 7. Results

In [None]:
import pandas as pd

# Tabulate the per-citation mean scores, then summarise them per verification label
# (or describe the whole set when only one label is present).
if not all_scores:
    print('No citations scored.')
else:
    df = pd.DataFrame(all_scores)
    num = [c for c in ['ics_mean','pos_mean','pfs_mean','bas_mean'] if c in df.columns]
    cols = [c for c in ['prompt_id','citation','label'] if c in df.columns] + num
    display(df[cols].round(4))
    print('\n=== Averages by Label ===')
    if len(df['label'].unique()) <= 1:
        print(f'All citations have label: {df["label"].iloc[0]}')
        print(df[num].describe().round(4).to_string())
    else:
        print(df.groupby('label')[num].mean().round(4).to_string())

In [None]:
import matplotlib.pyplot as plt

# Draw one panel per score showing the layer-wise mean for each verification label.
if not all_scores:
    print('No scores to plot.')
else:
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    score_keys = [('ics_scores','ICS'), ('pos_scores','POS'), ('pfs_scores','PFS'), ('bas_scores','BAS')]

    for ax, (key, title) in zip(axes, score_keys):
        for label in df['label'].unique():
            sub = df[df['label'] == label]
            means = []
            for l in range(NUM_LAYERS):
                # Only score lists long enough to have an entry for this layer contribute.
                vals = [s[l] for s in sub[key] if len(s) > l]
                if vals:
                    means.append(sum(vals)/len(vals))
                else:
                    means.append(0)
            color = {'real': 'green', 'fabricated': 'red', 'unverified': 'gray'}.get(label, 'blue')
            ax.plot(range(NUM_LAYERS), means, label=label, color=color, alpha=0.8, linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Layer')
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle(f'Mechanistic Scores — {MODEL_NAME.split("/")[-1]}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig('results/layer_scores.png', dpi=150, bbox_inches='tight')
    plt.show()
    print('Saved to results/layer_scores.png')

In [None]:
# Save results
# List-valued fields are rebuilt as lists of plain floats before dumping; the
# essays file is rewritten as well, now including the verification labels.
scores_out = [
    {k: ([float(x) for x in v] if isinstance(v, list) else v) for k, v in s.items()}
    for s in all_scores
]
with open('results/scores.json', 'w') as f:
    json.dump(scores_out, f, indent=2)
with open('results/essays.json', 'w') as f:
    json.dump(essays, f, indent=2)
print(f'Saved {len(scores_out)} scores and {len(essays)} essays to results/')

---
## 8. Summary Statistics

In [None]:
# Final recap: run configuration, citation counts, label distribution and the
# min/max/mean of each summary score.
if all_scores:
    print(f'Model: {MODEL_NAME}')
    print(f'Essays: {len(essays)}')
    print(f'Total citations: {sum(len(e["citations"]) for e in essays)}')
    print(f'Scored citations: {len(all_scores)}')
    print(f'Labels: {dict(df["label"].value_counts())}')
    print(f'\nScore ranges:')
    for s in ('ics_mean', 'pos_mean', 'pfs_mean', 'bas_mean'):
        if s not in df.columns:
            continue
        print(f'  {s}: {df[s].min():.4f} — {df[s].max():.4f} (mean {df[s].mean():.4f})')