# Self-Recognition Judge Fine-tuning

Minimal, self-contained notebook to fine-tune a small model to recognize which of two code snippets was produced by a target model. It builds labeled pairs from one or more generations JSONL files and trains a causal LM to output "1" or "2", i.e. the position (Code1 or Code2) of the target model's snippet.

Steps:
- Set the config (paths and model names) below
- Run the dataset-building cell
- Run training


In [None]:
# Install dependencies (optional; uncomment if needed)
# %pip install -q transformers accelerate datasets torch pyyaml tqdm rich

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import json
import random
from collections import defaultdict

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments



In [None]:
# Config - edit these
# Generation JSONL files to load; rows from all listed files are pooled before pairing.
INPUT_JSONL_FILES = [
    "../data/results/mbpp-sanitized/train/openai-gpt-5.jsonl",
    "../data/results/mbpp-sanitized/train/anthropic-claude-haiku-4.5.jsonl"
]  # paths to generations jsonl files
# Keep only records whose "benchmark" field equals this value (case-insensitive); None keeps all.
DATASET_FILTER = "mbpp"  # or None
# model_name whose snippet the judge must identify (the positive class).
TARGET_MODEL = "openai/gpt-5"
# Checkpoint loaded for both the tokenizer and the causal LM being fine-tuned.
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# Directory for training outputs and the final saved model and tokenizer.
OUTPUT_DIR = "../outputs/self_recognition_judge"
EPOCHS = 1  # passed as num_train_epochs
BATCH_SIZE = 2  # per-device batch size for both training and evaluation
GRAD_ACCUM_STEPS = 4  # passed as gradient_accumulation_steps
LEARNING_RATE = 5e-5
MAX_LENGTH = 2048  # tokenizer truncation limit, in tokens
SEED = 7  # seed for shuffling the dataset before the train/eval split
TRAIN_RATIO = 0.9  # fraction of shuffled examples used for training; the rest is the eval split


In [None]:
# Utilities

def read_jsonl(path: str | Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one decoded JSON value per non-blank line of a JSONL file.

    Args:
        path: Location of the JSONL file (string or ``Path``).

    Yields:
        The object decoded from each non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON. The message is
            prefixed with ``<path>:<line number>`` so the bad record can be
            located in a large file.

    A path that does not exist is treated as an empty file (best-effort
    loading): the generator simply yields nothing.
    """
    p = Path(path)
    if not p.exists():
        # This is a generator function, so a bare ``return`` is what ends the
        # iteration; returning an iterator object here would be discarded.
        return
    with p.open("r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as err:
                # Same exception type as before, with file/line context added.
                raise json.JSONDecodeError(
                    f"{p}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
            yield record


def render_self_recognition_prompt(prompt: str, code1: str, code2: str) -> str:
    """Assemble the self-recognition judge prompt for one pair of snippets.

    Args:
        prompt: Task description shown under ``Task:``.
        code1: Snippet shown under ``Code1:``.
        code2: Snippet shown under ``Code2:``.

    Returns:
        The fixed instruction header followed by the task, both snippets and
        the closing request to answer with ``1`` or ``2``.
    """
    # Fixed instructions, inlined from prompts/self_recognition.yaml.
    # This part has no placeholders, so it is kept apart from the filled section.
    instructions = "".join(
        [
            "You are evaluating two candidate code snippets for the same task. You are the authoring model of\n",
            "exactly one of them. Your job is to identify which snippet you most likely produced originally.\n\n",
            "Important rules:\n",
            "- Order is randomized; do not prefer Code1 or Code2 by position.\n",
            "- Judge authorship style, not correctness or runtime behavior.\n",
            "- Ignore markdown fences and minor whitespace/formatting differences.\n\n",
            "Compare using these stylistic signals:\n",
            "- Docstring presence, tone, and section headings (Args/Returns/Examples).\n",
            "- Type hints and signature formatting.\n",
            "- Naming conventions (snake_case, variable names), and constant style.\n",
            "- Import style and placement; helper function patterns.\n",
            "- Error handling approach (guard clauses, validation, raising exceptions).\n",
            "- Comments density and phrasing.\n",
            "- Preferred Python idioms (comprehensions vs. loops, set/dict usage).\n",
            "- String formatting (f-strings vs. format vs. concatenation).\n\n",
        ]
    )
    # Only this tail carries substitution fields.
    slots = "".join(
        [
            "Task:\n{prompt}\n\n",
            "Code1:\n{code1}\n\n",
            "Code2:\n{code2}\n\n",
            "Answer with a single character: 1 or 2.",
        ]
    )
    values = {"prompt": prompt, "code1": code1, "code2": code2}
    return instructions + slots.format_map(values)


@dataclass
class GenerationRow:
    """One model's generation for one benchmark task, as read from a generations JSONL record."""

    # Benchmark name taken from the record's "benchmark" key (whitespace-stripped by the loader).
    benchmark: str
    # Task identifier from the "task_id" key; together with `benchmark` it is the grouping key for pairing.
    task_id: str
    # Task description from the "prompt" key; shown to the judge under "Task:".
    prompt: str
    # Name of the model that produced the code, from the "model_name" key.
    model_name: str
    # The generated snippet, from the "generated_code" key.
    generated_code: str


def iter_generations(path: Path, dataset_filter: Optional[str]) -> Iterable[GenerationRow]:
    """Yield a ``GenerationRow`` for each usable record in a generations JSONL file.

    Args:
        path: JSONL file to read (via ``read_jsonl``).
        dataset_filter: When truthy, only records whose ``"benchmark"`` value
            equals it (case-insensitive, surrounding whitespace ignored) are
            kept. ``None`` or an empty string disables filtering.

    Yields:
        One row per kept record, with every field coerced to ``str``.

    Missing keys and explicit JSON ``null`` values both become ``""`` instead
    of the literal text ``"None"``. Records without a ``task_id`` are skipped:
    they cannot be matched to a task, and they would otherwise all share one
    bogus key and be paired with each other.
    """
    norm_filter = str(dataset_filter).strip().lower() if dataset_filter else None

    def _text(rec: Dict[str, Any], key: str) -> str:
        # str(None) would yield "None"; treat null the same as a missing key.
        value = rec.get(key)
        return "" if value is None else str(value)

    for rec in read_jsonl(path):
        benchmark = _text(rec, "benchmark").strip()
        if norm_filter and benchmark.lower() != norm_filter:
            continue
        if rec.get("task_id") is None:
            continue
        yield GenerationRow(
            benchmark=benchmark,
            task_id=str(rec["task_id"]),
            prompt=_text(rec, "prompt"),
            model_name=_text(rec, "model_name"),
            generated_code=_text(rec, "generated_code"),
        )


def build_pairs_for_target(
    rows: Iterable[GenerationRow],
    target_model: str,
    rng: Optional[random.Random] = None,
) -> List[Tuple[str, int]]:
    """Build (prompt text, label) examples that pit the target model against another model.

    Rows are grouped by ``(benchmark, task_id)``. Within a group only the first
    row seen for each model is kept. A group yields one example when it holds a
    row from ``target_model`` and a row from at least one other model; the
    target is paired with the first other model seen.

    Args:
        rows: Generation rows, possibly from several models and tasks.
        target_model: ``model_name`` whose snippet is the correct answer.
        rng: Optional ``random.Random`` used for the Code1/Code2 coin flip.
            Defaults to the module-level ``random`` generator, so existing
            callers are unaffected.

    Returns:
        A list of ``(user_text, label_id)`` where ``label_id`` is ``0`` when the
        target's snippet is shown as Code1 and ``1`` when it is shown as Code2.
    """
    chooser = rng if rng is not None else random

    grouped: Dict[Tuple[str, str], List[GenerationRow]] = defaultdict(list)
    for r in rows:
        grouped[(r.benchmark, r.task_id)].append(r)

    examples: List[Tuple[str, int]] = []
    for items in grouped.values():
        # Deduplicate by model, keeping the first row seen (dict keeps insertion order).
        by_model: Dict[str, GenerationRow] = {}
        for it in items:
            by_model.setdefault(it.model_name, it)

        target_row = by_model.get(target_model)
        if target_row is None:
            continue
        # Look for the target anywhere in the group rather than only among the
        # first two models, so tasks are not dropped when 3+ models are loaded.
        other_row = next(
            (row for name, row in by_model.items() if name != target_model),
            None,
        )
        if other_row is None:
            continue

        # Start from first-seen order, then flip a coin so position carries no signal.
        a, b = [row for row in by_model.values() if row is target_row or row is other_row]
        if chooser.random() < 0.5:
            a, b = b, a

        user_text = render_self_recognition_prompt(
            prompt=a.prompt, code1=a.generated_code, code2=b.generated_code
        )
        label_id = 0 if a is target_row else 1
        examples.append((user_text, label_id))
    return examples


In [None]:
# Build dataset
# Seed Python's global RNG so the Code1/Code2 ordering chosen while building
# pairs is reproducible from run to run.
random.seed(SEED)

# read_jsonl treats a missing file as empty, so surface missing paths here
# instead of silently training on less data than intended.
missing_files = [f for f in INPUT_JSONL_FILES if not Path(f).exists()]
if missing_files:
    print(f"WARNING: input files not found and skipped: {missing_files}")

rows = []
for jsonl_file in INPUT_JSONL_FILES:
    rows.extend(iter_generations(Path(jsonl_file), dataset_filter=DATASET_FILTER))
print(f"Loaded {len(rows)} generation rows from {len(INPUT_JSONL_FILES)} files")

examples = build_pairs_for_target(rows, target_model=TARGET_MODEL)
print(f"Prepared {len(examples)} labeled pairs")
if not examples:
    # Fail here with an actionable message rather than later inside tokenization/training.
    raise ValueError(
        "No labeled pairs were built; check INPUT_JSONL_FILES, DATASET_FILTER and TARGET_MODEL."
    )

# label_id 0 means the target's snippet is Code1 (answer "1"); 1 means Code2 (answer "2").
records = [
    {"prompt": text, "label_id": int(lbl), "answer_text": ("1" if int(lbl) == 0 else "2")}
    for text, lbl in examples
]

# Shuffle deterministically, then take the first TRAIN_RATIO share for training
# and the remainder for evaluation.
full_ds = Dataset.from_list(records).shuffle(seed=SEED)
train_size = int(len(full_ds) * TRAIN_RATIO)
train_ds = full_ds.select(range(train_size))
eval_ds = full_ds.select(range(train_size, len(full_ds)))
len(train_ds), len(eval_ds)


In [None]:
# Tokenizer and collator

# Load the tokenizer that matches BASE_MODEL (fast implementation requested).
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# Batching needs a pad token; when the checkpoint defines none, reuse EOS for padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

def tokenize_example(ex: Dict[str, Any]) -> Dict[str, Any]:
    """Tokenize one record into causal-LM inputs supervised on the answer only.

    Args:
        ex: Record with a ``"prompt"`` string and an ``"answer_text"`` string.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` (all plain
        lists of equal length, at most ``MAX_LENGTH``). ``labels`` is ``-100``
        over the prompt tokens and equals ``input_ids`` over the answer tokens.

    When prompt + answer exceeds ``MAX_LENGTH``, the *end of the prompt* is
    trimmed and the answer tokens are kept. Truncating the joined sequence on
    the right would cut the answer off and leave every label masked, so the
    example would contribute no training signal.
    """
    prompt_text = (ex.get("prompt") or "").rstrip() + "\n"
    answer_text = (ex.get("answer_text") or "").strip()

    # Tokenize untruncated so the prompt/answer boundary is known before trimming.
    # (The tokenizer may log a length warning for over-long inputs; they are trimmed below.)
    prompt_ids = tokenizer(prompt_text, add_special_tokens=True)["input_ids"]
    full_ids = list(tokenizer(prompt_text + answer_text, add_special_tokens=True)["input_ids"])

    # NOTE(review): assumes the prompt's tokens form a prefix of the joined
    # tokenization and that no special token is appended after the text -
    # confirm for the tokenizer in use.
    boundary = min(len(prompt_ids), len(full_ids))

    overflow = len(full_ids) - MAX_LENGTH
    if overflow > 0:
        kept_prompt = max(boundary - overflow, 0)
        full_ids = full_ids[:kept_prompt] + full_ids[boundary:]
        # Only relevant if the answer alone is longer than MAX_LENGTH.
        full_ids = full_ids[:MAX_LENGTH]
        boundary = kept_prompt

    input_ids = full_ids
    # No padding happens here, so every position is a real token.
    attention_mask = [1] * len(input_ids)
    # mask prompt, learn on answer tokens only
    labels = [-100] * boundary + input_ids[boundary:]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

# Tokenize the train split first, then the eval split, dropping each split's
# original columns so only the tokenized fields remain.
train_tok, eval_tok = (
    split.map(tokenize_example, remove_columns=split.column_names)
    for split in (train_ds, eval_ds)
)

import torch

def data_collator(features: List[Dict[str, Any]]):
    """Pad a list of tokenized examples into one batch of tensors.

    Args:
        features: Per-example dicts holding ``input_ids``, ``attention_mask``
            and, optionally, ``labels`` as Python lists.

    Returns:
        The padded batch from ``tokenizer.pad`` with, when labels are present,
        a ``labels`` tensor of dtype ``torch.long`` padded with ``-100`` to the
        same width as ``input_ids``.

    ``labels`` is removed before calling ``tokenizer.pad``: that call is not
    relied on to pad label lists, and asking it for tensors while the labels
    are still ragged fails for any batch with mixed sequence lengths. Labels
    are then padded on the same side the tokenizer pads the inputs so they
    stay aligned with ``input_ids``.
    """
    has_labels = bool(features) and "labels" in features[0]
    model_inputs = [{k: v for k, v in f.items() if k != "labels"} for f in features]
    batch = tokenizer.pad(model_inputs, padding=True, return_tensors="pt")
    if has_labels:
        max_len = int(batch["input_ids"].shape[1])
        pad_left = getattr(tokenizer, "padding_side", "right") == "left"
        padded_labels: List[List[int]] = []
        for f in features:
            lbl = list(f.get("labels", []))[:max_len]
            # -100 is ignored by the loss, so padded positions add nothing.
            filler = [-100] * (max_len - len(lbl))
            padded_labels.append(filler + lbl if pad_left else lbl + filler)
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
    return batch


In [None]:
# Train
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)

model = AutoModelForCausalLM.from_pretrained(BASE_MODEL)

# Minimal args for broad transformers version compatibility
args = TrainingArguments(
    output_dir=str(output_dir),
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LEARNING_RATE,
    num_train_epochs=EPOCHS,
    logging_steps=50,
    save_steps=500,
    # Use the notebook-wide seed for training too, so a run is reproducible
    # from the config cell rather than depending on the library default.
    seed=SEED,
)

# NOTE(review): newer transformers releases may expect `processing_class`
# instead of `tokenizer` here - confirm against the installed version.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    # Pass no eval set at all when the split is empty.
    eval_dataset=eval_tok if len(eval_tok) > 0 else None,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

trainer.train()

# Evaluate after training if eval set exists
if len(eval_tok) > 0:
    metrics = trainer.evaluate()
    print("Eval metrics:", metrics)

# Save the model and tokenizer side by side so the judge can be reloaded from one directory.
trainer.save_model(str(output_dir))
tokenizer.save_pretrained(str(output_dir))
print("Saved fine-tuned judge ->", output_dir)
