### Tinker phishing classifier (clean)

This notebook runs **inference** with your fine-tuned Tinker sampler and produces a final DataFrame that includes:
- `category`: the model prediction (`GENUINE`/`PHISHING`/`ASSIGN_TO_AGENT`)
- `agent_notes_pred`: the model-provided agent notes extracted from the assistant text (e.g. `agent_notes: ...`)

**Prereqs**
- `TINKER_API_KEY` is set (this notebook loads `.env` if present).
- You have a completed run log with `checkpoints.jsonl` at `LOG_PATH` (default: `/tmp/tinker-examples/sl_ar_phishing`).

**Outputs**
- `df_tinker`: per-ticket predictions + `agent_notes_pred`
- `df_phishing_messages_tinker`: your evaluation DataFrame with predictions merged in


In [1]:
# Imports + environment

from __future__ import annotations

import asyncio
import json
import re
from enum import StrEnum
from pathlib import Path
from typing import Any, Dict, List

import pandas as pd
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from tqdm.auto import tqdm

load_dotenv()


  from .autonotebook import tqdm as notebook_tqdm


True

In [2]:
lora_train = pd.read_parquet("../data/finetuning/lora_train_emails.parquet")
lora_test = pd.read_parquet("../data/finetuning/lora_test_emails.parquet")

display(lora_train["label"].value_counts(dropna=False))
display(lora_test["label"].value_counts(dropna=False))


label
1    2241
0    2231
Name: count, dtype: int64

label
1    560
0    558
Name: count, dtype: int64

### 1) Benchmark (Azure OpenAI, structured output)

This section runs the baseline classifier using Azure OpenAI with structured outputs (`VerificationClassification`).

- Requires `AZURE_OPENAI_API_KEY`, `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_API_VERSION`, `AZURE_OPENAI_DEPLOYMENT`
- Produces `df_gpt` and `df_phishing_messages_gpt`


In [4]:
# GPT-5.2 benchmark (spam) — parse `is_spam: 1/0`
# Uses the same lightweight label format as your fine-tuned Tinker model.

from dotenv import load_dotenv

# Make env loading robust even if the notebook's CWD isn't repo root.
load_dotenv(dotenv_path=Path("../.env"), override=False)

import os
import re
from typing import Any, Dict, List

from openai import AsyncAzureOpenAI, AsyncOpenAI

AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

missing = [
    k
    for k, v in {
        "AZURE_OPENAI_API_KEY": AZURE_OPENAI_API_KEY,
        "AZURE_OPENAI_ENDPOINT": AZURE_OPENAI_ENDPOINT,
        "AZURE_OPENAI_API_VERSION": AZURE_OPENAI_API_VERSION,
        "OPENAI_API_KEY": OPENAI_API_KEY,
    }.items()
    if not v
]
if missing:
    raise RuntimeError(
        "Missing env vars: "
        + ", ".join(missing)
        + "\nSet them in your .env (or environment) and re-run this cell."
    )

# Azure client kept around (handy for other experiments), but this benchmark uses OPENAI_API_KEY.
azure_client = AsyncAzureOpenAI(
    api_key=AZURE_OPENAI_API_KEY,
    azure_endpoint=AZURE_OPENAI_ENDPOINT,
    api_version=AZURE_OPENAI_API_VERSION,
)
oai_client = AsyncOpenAI(api_key=OPENAI_API_KEY)

print("OpenAI client ready")


system_prompt = Path("../data/verification/system_prompt_spam.md").read_text()

def _build_email_user_prompt(*, subject: str, body: str) -> str:
    return f"""
Email subject:
---------------------------------------------------
\n{subject}\n\n
---------------------------------------------------
Email body:
---------------------------------------------------
\n{body}\n\n
---------------------------------------------------
""".strip()


def _parse_is_spam(text: str) -> bool | None:
    m = re.search(r"\bis_spam\s*:\s*(-1|0|1|true|false)\b", text, flags=re.IGNORECASE)
    if m is None:
        return None

    v = m.group(1).lower()
    if v == "-1":
        return None

    return v in ("1", "true")


async def classify_spam_gpt52(
    *,
    subject: str,
    body: str,
    system_prompt: str = system_prompt,
) -> tuple[bool | None, str]:
    user_prompt = _build_email_user_prompt(subject=subject, body=body)

    response = await oai_client.responses.create(
        model="gpt-5.2-2025-12-11",
        input=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )

    text = (getattr(response, "output_text", "") or "").strip()
    return _parse_is_spam(text), text


async def process_all_emails_gpt52(
    df_messages: pd.DataFrame,
    *,
    concurrency: int = 50,
    max_rows: int | None = None,
) -> List[Dict[str, Any]]:
    """Run GPT-5.2 spam classification over a dataframe of emails."""

    sem = asyncio.Semaphore(concurrency)
    rows = list(df_messages.iterrows())
    if max_rows is not None:
        rows = rows[:max_rows]

    async def process_single(row_idx: Any, row: pd.Series) -> Dict[str, Any]:
        async with sem:
            try:
                is_spam_pred, raw_text = await classify_spam_gpt52(
                    subject=str(row.get("subject", "")),
                    body=str(row.get("body", "")),
                )
                return {
                    "ticket_id": row.get("ticket_id", row_idx),
                    "is_spam_pred": is_spam_pred,
                    "raw": raw_text,
                    "success": True,
                    "error": None,
                }
            except Exception as e:
                return {
                    "ticket_id": row.get("ticket_id", row_idx),
                    "is_spam_pred": None,
                    "raw": None,
                    "success": False,
                    "error": repr(e),
                }

    tasks = [asyncio.create_task(process_single(row_idx, row)) for row_idx, row in rows]

    results: list[Dict[str, Any]] = []
    for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        results.append(await fut)

    return results


gpt_results = await process_all_emails_gpt52(lora_test.iloc[:5], concurrency=50, max_rows=None)
df_gpt = pd.DataFrame(gpt_results)

# If you want to persist results:
# df_gpt.to_parquet("../data/verification/gpt52_results_spam.parquet")
# df_gpt = pd.read_parquet("../data/verification/gpt52_results_spam.parquet")

# Merge predictions back into the evaluation dataframe.
if "ticket_id" in lora_test.columns:
    df_eval = lora_test
else:
    df_eval = lora_test.reset_index().rename(columns={"index": "ticket_id"})

df_spam_messages_gpt = df_eval.merge(
    df_gpt[["ticket_id", "is_spam_pred", "raw", "success", "error"]],
    on="ticket_id",
    how="inner",
)



display(df_spam_messages_gpt["is_spam_pred"].value_counts(dropna=False))


OpenAI client ready


100%|██████████| 5/5 [00:01<00:00,  4.34it/s]


is_spam_pred
False    3
True     2
Name: count, dtype: int64

In [5]:
# Quick metrics (optional)

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix

# Ground truth for the email dataset is `label` (bool/int).
# Predictions come from GPT as `is_spam_pred` (bool | None).
mask = df_spam_messages_gpt["is_spam_pred"].notna()

y_true = df_spam_messages_gpt.loc[mask, "label"].astype(bool)
y_pred = df_spam_messages_gpt.loc[mask, "is_spam_pred"].astype(bool)

print("GPT accuracy:", accuracy_score(y_true, y_pred))

precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)
# [[TN, FP],
#  [FN, TP]]

print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")
print(f"Confusion Matrix: {cm}")


GPT accuracy: 1.0
Precision: 1.0000, Recall: 1.0000, F1 Score: 1.0000
Confusion Matrix: [[3 0]
 [0 2]]


### 3) Create a SamplingClient for the fine-tuned checkpoint

We load the last sampler checkpoint from `LOG_PATH/checkpoints.jsonl` and create a `sampling_client` for inference.


In [6]:
# Make `tinker_cookbook` importable without installing it.

import sys

def find_tinker_cookbook_repo(start: Path) -> Path:
    candidates: list[Path] = []
    for p in [start, *start.parents]:
        candidates.append(p / "tinker-cookbook")
        candidates.append(p)
    for repo in candidates:
        if (repo / "tinker_cookbook" / "__init__.py").exists():
            return repo
    raise FileNotFoundError(
        "Could not find the tinker-cookbook repo (missing tinker_cookbook/__init__.py)."
    )

TINKER_COOKBOOK_REPO = find_tinker_cookbook_repo(Path.cwd())
if str(TINKER_COOKBOOK_REPO) not in sys.path:
    sys.path.insert(0, str(TINKER_COOKxBOOK_REPO))

import tinker_cookbook  # noqa: F401

print("tinker-cookbook repo:", TINKER_COOKBOOK_REPO)


tinker-cookbook repo: /Users/ext-elias.melas/Documents/Gitcode/tinker-cookbook


In [7]:
# Choose where inference runs:
# - True: run LOCALLY using your downloaded adapter at ../adapters/lora_adapter
# - False: run REMOTELY on Tinker using the sampler_path from LOG_PATH/checkpoints.jsonl
USE_LOCAL_ADAPTER = False

# --- Local adapter inference (Transformers + PEFT) ---
if USE_LOCAL_ADAPTER:
    ####### TODO (WIP) ########
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    BASE_MODEL_DIR = Path("../models/llama-3.1-8b")
    ADAPTER_DIR = Path("../adapters/lora_adapter_spam")

    if not BASE_MODEL_DIR.exists():
        raise FileNotFoundError(f"Base model not found: {BASE_MODEL_DIR}")
    if not ADAPTER_DIR.exists():
        raise FileNotFoundError(f"Adapter not found: {ADAPTER_DIR}")

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        else "cpu"
    )
    dtype = torch.float16 if device in {"cuda", "mps"} else torch.float32

    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_DIR)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL_DIR, torch_dtype=dtype)
    base_model.to(device)

    model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
    model.eval()

    def _render_role_colon(messages: list[dict[str, str]]) -> str:
        # Match the role_colon style used during training.
        parts: list[str] = []
        for m in messages:
            role = (m.get("role") or "").strip().lower()
            content = (m.get("content") or "").strip()
            if role == "system":
                parts.append(f"System: {content}")
            elif role == "user":
                parts.append(f"User: {content}")
            elif role == "assistant":
                parts.append(f"Assistant: {content}")
            else:
                parts.append(f"{role.title()}: {content}")
        # Add generation cue
        return "\n\n".join(parts) + "\n\nAssistant:"

    async def tinker_completer(messages: list[dict[str, str]]) -> dict[str, str]:
        # Keep the same interface as TinkerMessageCompleter: returns {role, content}.
        prompt = _render_role_colon(messages)

        inputs = tokenizer(prompt, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=64,
                do_sample=False,
                temperature=0.0,
                pad_token_id=tokenizer.eos_token_id,
            )

        gen = out[0][inputs["input_ids"].shape[-1] :]
        text = tokenizer.decode(gen, skip_special_tokens=True).strip()

        # Trim if the model starts a new turn
        text = text.split("\n\nUser:", 1)[0].strip()

        return {"role": "assistant", "content": text}

    print("Local inference enabled")
    print("Device:", device)
    print("Base model:", BASE_MODEL_DIR)
    print("Adapter:", ADAPTER_DIR)

# --- Remote Tinker inference (SamplingClient) ---
else:
    import tinker
    from tinker_cookbook.checkpoint_utils import get_last_checkpoint

    # Update this if your run logs are elsewhere.
    LOG_PATH = Path("/tmp/tinker-examples/sl_ar_phishing_spam")

    service_client = tinker.ServiceClient()

    ckpt_sampler = get_last_checkpoint(str(LOG_PATH), required_key="sampler_path")
    print("Last sampler checkpoint:", ckpt_sampler.get("sampler_path") if ckpt_sampler else None)

    if not ckpt_sampler:
        raise RuntimeError(
            "No sampler checkpoint found. Ensure training ran and wrote checkpoints.jsonl under LOG_PATH."
        )

    sampling_client = service_client.create_sampling_client(model_path=ckpt_sampler["sampler_path"])
    print("Sampling client ready")


Last sampler checkpoint: tinker://ab71f6d1-f9b1-5f68-b387-59bb7b34c066:train:0/sampler_weights/final
Sampling client ready


### 4) Build the Tinker message completer

 In the Tinker Cookbook, policies are implemented as Completers. Completers are abstractions that represent models or policies that can be sampled from, providing different levels of structure depending on your use case.

The Tinker Cookbook provides two main types of completers, each designed for different use cases:
- TokenCompleter: Operates on tokens and is used by RL algorithms
- MessageCompleter: Operates on messages and needs to be used with a renderer
The choice between these depends on whether you're working at the token level for RL training or at the message level for interacting with and evaluating the model.

We build a renderer/tokenizer matching the base model family and then create a `TinkerMessageCompleter` that returns assistant messages.


In [8]:
# Build the message completer
# - If USE_LOCAL_ADAPTER=True, `tinker_completer` was defined in the previous cell (local Transformers+PEFT).
# - Otherwise, create a remote TinkerMessageCompleter.

MODEL_NAME = "meta-llama/Llama-3.1-8B"

if not USE_LOCAL_ADAPTER:
    from tinker_cookbook import model_info
    from tinker_cookbook.completers import TinkerMessageCompleter
    from tinker_cookbook import renderers as cookbook_renderers
    from tinker_cookbook.tokenizer_utils import get_tokenizer

    RENDERER_NAME = model_info.get_recommended_renderer_name(MODEL_NAME)

    tok = get_tokenizer(MODEL_NAME)
    renderer = cookbook_renderers.get_renderer(RENDERER_NAME, tok)

    tinker_completer = TinkerMessageCompleter(
        sampling_client=sampling_client,
        renderer=renderer,
        max_tokens=64,
    )

    print("Remote Tinker inference enabled")
    print("Model:", MODEL_NAME)
    print("Renderer:", RENDERER_NAME)
    print("Stop sequences:", renderer.get_stop_sequences())
else:
    print("Local adapter inference enabled")
    print("Model:", MODEL_NAME)


Remote Tinker inference enabled
Model: meta-llama/Llama-3.1-8B
Renderer: role_colon
Stop sequences: ['\n\nUser:']


### 4) Classify with Tinker adapter

Your fine-tuned model returns assistant text like:
- `is_phishing: True, agent_notes: ...`

We parse both fields and store the notes as `agent_notes_pred`.


In [23]:
class SpamClassification(BaseModel):
    is_spam: bool | None = Field(...)


# Spam system prompt for the email classifier fine-tune.
system_prompt = Path("../data/verification/system_prompt_spam.md").read_text()


def _build_user_prompt(*, subject: str, body: str) -> str:
    return f"""
Email subject:
---------------------------------------------------
\n{subject}\n\n
---------------------------------------------------
Email body:
---------------------------------------------------
\n{body}\n\n
---------------------------------------------------
"""


async def classify_phishing_tinker(
    *,
    subject: str,
    body: str,
    system_prompt: str = system_prompt,
) -> tuple[VerificationClassification, str | None]:
    user_prompt = _build_user_prompt(subject=subject, body=body)

    convo = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    assistant_msg = await tinker_completer(convo)
    print(assistant_msg)
    text = str(assistant_msg.get("content", "")).strip()

    # Expected assistant output (common formats):
    # - "is_spam: 1"
    # - "is_spam: 0"
    # - "is_spam: True" / "is_spam: False"
    m = re.search(r"\bis_spam\s*:\s*(1|0|true|false)\b", text, flags=re.IGNORECASE)
    if m is None:
        return SpamClassification(is_spam=False), None

    v = m.group(1).lower()
    is_spam = v in ("1", "true")

    # NOTE: We reuse the original PHISHING/GENUINE categories.
    # Here, PHISHING corresponds to the positive class (spam).
    is_spam = SpamClassification(is_spam=is_spam)

    notes_match = re.search(r"\bagent_notes\s*:\s*(.*)$", text, flags=re.IGNORECASE | re.DOTALL)
    agent_notes_pred = notes_match.group(1).strip() if notes_match else None
    if agent_notes_pred == "":
        agent_notes_pred = None

    return is_spam, agent_notes_pred


In [46]:
async def process_all_messages_phishing_tinker(
    df_messages: pd.DataFrame,
    *,
    concurrency: int = 16,
    max_rows: int | None = None,
) -> List[Dict[str, Any]]:
    """Run Tinker inference over a dataframe of rows."""

    sem = asyncio.Semaphore(concurrency)
    rows = list(df_messages.iterrows())
    if max_rows is not None:
        rows = rows[:max_rows]

    async def process_single(row_idx: Any, row: pd.Series) -> Dict[str, Any]:
        async with sem:
            pred, agent_notes_pred = await classify_phishing_tinker(
                subject=str(row.get("subject", "")),
                body=str(row.get("body", "")),
            )

            # Store category as a plain string for easier pandas usage.
            is_spam = pred.is_spam.value if hasattr(pred.is_spam, "value") else str(pred.is_spam)

            # Some datasets don't have ticket_id; fall back to the row index.
            ticket_id = row.get("ticket_id", row_idx)

            return {
                "ticket_id": ticket_id,
                "is_spam_pred": is_spam,
                "success": True,
                "error": None,
            }

    tasks = [asyncio.create_task(process_single(row_idx, row)) for row_idx, row in rows]

    results: list[Dict[str, Any]] = []
    for fut in tqdm(asyncio.as_completed(tasks), total=len(tasks)):
        results.append(await fut)

    return results


tinker_results = await process_all_messages_phishing_tinker(lora_test.iloc[:10], max_rows=None)
df_tinker = pd.DataFrame(tinker_results)
# df_tinker.to_parquet("../data/verification/tinker_results_spam.parquet")
# df_tinker = pd.read_parquet("../data/verification/tinker_results_spam.parquet")
df_tinker.head()


 30%|███       | 3/10 [00:01<00:02,  2.88it/s]

{'role': 'assistant', 'content': 'is_spam: 1'}
{'role': 'assistant', 'content': 'is_spam: 0'}
{'role': 'assistant', 'content': 'is_spam: 0'}


100%|██████████| 10/10 [00:01<00:00,  6.26it/s]

{'role': 'assistant', 'content': 'is_spam: 0'}
{'role': 'assistant', 'content': 'is_spam: 0'}
{'role': 'assistant', 'content': 'is_spam: 1'}
{'role': 'assistant', 'content': 'is_spam: 1'}
{'role': 'assistant', 'content': 'is_spam: 1'}
{'role': 'assistant', 'content': 'is_spam: 0'}
{'role': 'assistant', 'content': 'is_spam: 0'}





Unnamed: 0,ticket_id,is_spam_pred,success,error
0,1451,True,True,
1,808,False,True,
2,2659,False,True,
3,3829,False,True,
4,5550,False,True,


In [69]:
# Merge predictions back into the evaluation dataframe.
# Some datasets (like Enron) don't have a ticket_id column.
if "ticket_id" in lora_test.columns:
    df_eval = lora_test
else:
    df_eval = lora_test.reset_index().rename(columns={"index": "ticket_id"})

df_phishing_messages_tinker = df_eval.merge(
    df_tinker[["ticket_id","is_spam_pred"]],
    on="ticket_id",
    how="inner",
)
df_phishing_messages_tinker['pred_bool'] = df_phishing_messages_tinker['is_spam_pred'].map(lambda x: True if x == "True" else False)

display(
    df_phishing_messages_tinker["is_spam_pred"].value_counts(dropna=False),
    df_phishing_messages_tinker
)


is_spam_pred
False    6
True     4
Name: count, dtype: int64

Unnamed: 0,ticket_id,subject,body,label,is_spam_pred,pred_bool
0,5438,cera conference call and web presentation : wi...,scenarios . . . - cera conference call\r\ncera...,0,False,False
1,808,"revision # 1 - hpl nom for july 25 , 2000",( see attached file : hplo 725 . xls )\r\n- hp...,0,False,False
2,4497,you have won congratulation ! ! lucky winner !...,dayzer lottery national promotion .\r\npostbus...,1,True,True
3,2662,"ambllen , alprazzolam , \\ / aluum , \\ / llgr...",terrible showed gotten teacher among . speech ...,1,True,True
4,3829,hr generalist for your group,norma villarreal is the hr generalist for all ...,0,False,False
5,1451,appointment on sunday at 18 - 00,remove\r\nperch brunoguffaw handicapper conjug...,1,True,True
6,2083,fw : epmi / aep / allegheny ring transaction c...,below please find proposed confirms for the pr...,0,False,False
7,5550,february 7 th update,"jeff / michelle / ken ,\r\nhere is the daily u...",0,False,False
8,2659,asking for advice regarding summer associate p...,"shirley ,\r\nplease , set up a phone interview...",0,False,False
9,2806,promote your business,the power of email marketing\r\nemail marketin...,1,True,True


In [None]:
# df_phishing_messages_tinker.dropna(subset=['category'],inplace=True)

In [66]:
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


# Ground truth for the email dataset is `label` (bool/int).
y_true = (
    df_phishing_messages_tinker["label"]
).astype(bool)

y_pred = (
    df_phishing_messages_tinker["pred_bool"]
)

precision = precision_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
cm = confusion_matrix(y_true, y_pred)
# [[TN, FP],
#  [FN, TP]]

print("Fine-tuned Llama 3.1-8B accuracy:", accuracy_score(y_true, y_pred))
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1 Score: {f1:.4f}")
print(f"Confusion Matrix: {cm}")

Fine-tuned Llama 3.1-8B accuracy: 1.0
Precision: 1.0000, Recall: 1.0000, F1 Score: 1.0000
Confusion Matrix: [[6 0]
 [0 4]]
