In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 1 – Imports, paths, reproducibility 🔧📂  ║
# ╚════════════════════════════════════════════════╝
import os, json, random
from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from openai import AzureOpenAI
from tenacity import retry, wait_random_exponential, stop_after_attempt
from tqdm.auto import tqdm

SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# --- environment ---
load_dotenv()                                      # ENDPOINT_URL, DEPLOYMENT_NAME, AZURE_OPENAI_API_KEY
DEPLOYMENT_NAME  = os.getenv("DEPLOYMENT_NAME", "").strip()
AZ_VERSION       = os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview")

if not DEPLOYMENT_NAME:
    # Fail fast here instead of sending an empty model name with the first API call.
    raise EnvironmentError("DEPLOYMENT_NAME is not set; add it to your environment or .env file")

# --- directories / files ---
# "V1" is the prompt being refined (currently v4.txt); "V2" is where the refined
# prompt is written (currently v5.txt).
PROMPT_V1_PATH   = Path("prompts/v4.txt")
DATASET_PATH     = Path("outputs/datasets/train_dataset.csv")   # always use TRAIN here
MISCLASS_DIR     = Path("outputs") / "gpt-4.1" / "train" / "misclassified"
PROMPT_V2_PATH   = Path("prompts/v5.txt")                       # new prompt will be written here

# Report every missing input at once rather than stopping at the first one.
missing_paths = [p for p in (PROMPT_V1_PATH, DATASET_PATH, MISCLASS_DIR) if not p.exists()]
if missing_paths:
    raise FileNotFoundError(
        "Missing expected path(s): " + ", ".join(str(p) for p in missing_paths)
    )


In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 2 – Collect misclassified examples 🗂️    ║
# ╚════════════════════════════════════════════════╝
# Load train split so we can attach full title/abstract
# IDs are matched against JSON file stems (strings), so read `id` as str to avoid
# an int/str mismatch in `isin` and to keep any leading zeros intact.
train_df = pd.read_csv(DATASET_PATH, dtype={"id": str})

def load_example(json_path: Path) -> Dict:
    """Read one misclassification JSON record and tag it with its ID.

    The ID is the file name without its extension (``json_path.stem``); it is
    stored under the ``"id"`` key, overwriting any existing value there.
    """
    record = json.loads(json_path.read_text(encoding="utf-8"))
    record["id"] = json_path.stem
    return record

# Sort the glob so the pre-shuffle order (and therefore the seeded shuffle below)
# does not depend on the filesystem's directory-listing order.
records = [load_example(p) for p in sorted(MISCLASS_DIR.glob("*.json"))]

false_pos, false_neg = [], []
for rec in records:
    gt, pred = rec["ground_truth"], rec["prediction"]
    if gt == "Excluded" and pred == "Included":
        false_pos.append(rec["id"])     # predicted Included, truly Excluded
    elif gt == "Included" and pred == "Excluded":
        false_neg.append(rec["id"])     # predicted Excluded, truly Included

# Optional cap on examples per error type; None keeps all of them
# (set e.g. 10 to keep the meta-prompt short).
MAX_EXAMPLES_PER_CLASS = None

# A dedicated seeded RNG gives the same selection on every re-run of this cell,
# regardless of how much global `random` state other cells have consumed.
shuffle_rng = random.Random(SEED)
shuffle_rng.shuffle(false_pos)
shuffle_rng.shuffle(false_neg)
false_pos = false_pos[:MAX_EXAMPLES_PER_CLASS]
false_neg = false_neg[:MAX_EXAMPLES_PER_CLASS]

def fetch_rows(id_list: List[str], df: "pd.DataFrame | None" = None) -> List[Dict]:
    """Return id/title/abstract/label records for the given IDs.

    Args:
        id_list: IDs to look up. They are compared as strings, so file-stem IDs
            still match a numeric ``id`` column.
        df: frame to search; defaults to the module-level ``train_df``.

    Returns:
        One dict per matching row, in the frame's row order (not ``id_list``
        order). IDs with no matching row are silently skipped.
    """
    source = train_df if df is None else df
    wanted = [str(i) for i in id_list]
    subset = source[source["id"].astype(str).isin(wanted)]
    return subset[["id", "title", "abstract", "label"]].to_dict("records")

fp_rows = fetch_rows(false_pos)
fn_rows = fetch_rows(false_neg)

print(f"Selected {len(fp_rows)} false-positives and {len(fn_rows)} false-negatives")

# fetch_rows drops IDs that have no row in the train CSV; make that visible.
found_ids = {str(r["id"]) for r in fp_rows + fn_rows}
missing_ids = [i for i in false_pos + false_neg if str(i) not in found_ids]
if missing_ids:
    print(f"Warning: {len(missing_ids)} misclassified ID(s) not found in {DATASET_PATH}, "
          f"e.g. {missing_ids[:5]}")

# Without any examples the meta-prompt has nothing to ground the rewrite on.
if not fp_rows and not fn_rows:
    raise ValueError("No misclassified examples matched the train split; nothing to refine the prompt with.")


In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 3 – Build meta-prompt for the LLM ✍️      ║
# ╚════════════════════════════════════════════════╝
# Load the current (base) prompt text that the LLM will be asked to refine.
with PROMPT_V1_PATH.open("r", encoding="utf-8") as prompt_file:
    orig_prompt = prompt_file.read()

def format_example(ex: Dict) -> str:
    """Render one example as a '---'-delimited block for the meta-prompt.

    Reads the ``id``, ``label``, ``title`` and ``abstract`` keys of ``ex`` and
    emits one ``NAME: value`` line for each, preceded by a ``---`` separator.
    """
    fields = (
        ("ID", ex["id"]),
        ("TRUE LABEL", ex["label"]),
        ("TITLE", ex["title"]),
        ("ABSTRACT", ex["abstract"]),
    )
    body = "".join(f"{name}: {value}\n" for name, value in fields)
    return "\n---\n" + body

# Label each section with its error direction so the model knows which way the
# original prompt erred, and write "(none)" rather than leaving a bare header.
fp_text = "".join(format_example(e) for e in fp_rows) or "\n(none)"
fn_text = "".join(format_example(e) for e in fn_rows) or "\n(none)"
examples_text  = "\nFALSE POSITIVES (predicted Included, true label Excluded):" + fp_text
examples_text += "\n\nFALSE NEGATIVES (predicted Excluded, true label Included):" + fn_text

# The original prompt is fenced with explicit markers so its end cannot be
# confused with the instructions and examples that follow it.
meta_prompt = f"""
You are an expert prompt-engineer.

TASK:
Rewrite the ORIGINAL PROMPT so that an LLM classifies abstracts as **Included** or **Excluded** with higher accuracy.
Return ONLY the improved prompt text, nothing else.

CONSTRAINTS:
• Minimize your changes to the original prompt.
• The changes you make should be based on the misclassified examples provided.
• All changes should be small and incremental, not large rewrites.
• Keep the JSON schema unchanged.
• Use clear, concise instructions.
• Do not include the BEGIN/END ORIGINAL PROMPT markers in your answer.

ORIGINAL PROMPT:
<<<BEGIN ORIGINAL PROMPT>>>
{orig_prompt.strip()}
<<<END ORIGINAL PROMPT>>>

Below are misclassified examples from the training split to help you refine the instructions:
{examples_text}

Respond with the new prompt only.
""".strip()

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 4 – Call Azure OpenAI for new prompt 🤖   ║
# ╚════════════════════════════════════════════════╝
def get_client() -> AzureOpenAI:
    """Build an Azure OpenAI client from environment settings.

    Reads the key from ``AZURE_OPENAI_API_KEY`` and the endpoint from
    ``ENDPOINT_URL``, and uses the API version from ``AZ_VERSION``.
    """
    settings = {
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "azure_endpoint": os.getenv("ENDPOINT_URL"),
        "api_version": AZ_VERSION,
    }
    return AzureOpenAI(**settings)

# Module-level client used by call_llm below; created once each time this cell runs.
client = get_client()


@retry(
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(5),
    reraise=True,  # surface the real API error instead of tenacity's RetryError
)
def call_llm(prompt: str, max_tokens: int = 4000, temperature: float = 0.7) -> str:
    """Send `prompt` to the configured deployment and return the stripped reply.

    Some deployments reject certain parameters (the token-limit name or the
    sampling settings). The argument sets below are therefore tried in order,
    and the next one is used whenever the error message looks like an
    unsupported-parameter rejection.

    Args:
        prompt: user message content.
        max_tokens: completion token budget, sent as ``max_completion_tokens``
            or ``max_tokens``, whichever the model accepts.
        temperature: sampling temperature for the argument sets that include it.

    Returns:
        The assistant message content with surrounding whitespace removed.

    Raises:
        RuntimeError: if the reply was cut off at the token limit or is empty.
            The whole call is retried like any other failure.
        Exception: any non-parameter API error as raised by the client, or the
            last parameter-rejection error if every argument set was rejected.
    """
    messages = [
        {"role": "system", "content": "You are a helpful AI assistant."},
        {"role": "user", "content": prompt},
    ]

    # Try a series of argument sets, falling back when the model rejects a parameter
    attempts = [
        {"max_completion_tokens": max_tokens, "temperature": temperature, "top_p": 1},
        {"max_tokens": max_tokens, "temperature": temperature, "top_p": 1},
        {"max_completion_tokens": max_tokens},
        {"max_tokens": max_tokens},
    ]
    param_error_markers = ("unsupported_parameter", "unsupported_value", "does not support")

    last_err = None
    for kwargs in attempts:
        try:
            resp = client.chat.completions.create(
                model=DEPLOYMENT_NAME,
                messages=messages,
                **kwargs,
            )
        except Exception as e:
            # Fall through to the next argument set only for parameter rejections.
            if any(marker in str(e) for marker in param_error_markers):
                last_err = e
                continue
            raise  # re-raise unexpected errors

        choice = resp.choices[0]
        # A truncated reply would be saved as an incomplete prompt, so reject it.
        if choice.finish_reason == "length":
            raise RuntimeError(
                f"LLM reply was truncated at the token limit ({max_tokens}); "
                "increase max_tokens"
            )
        content = choice.message.content
        # content can be None (e.g. no text returned); avoid an opaque AttributeError.
        if not content or not content.strip():
            raise RuntimeError("LLM returned an empty reply")
        return content.strip()

    # Every argument set was rejected as unsupported: re-raise the last rejection.
    raise last_err  # type: ignore[misc]


# Ask the model for the refined prompt and report its size in whitespace-separated words.
new_prompt_text = call_llm(meta_prompt)
word_count = len(new_prompt_text.split())
print("LLM returned a prompt with", word_count, "words")


In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 5 – Save new prompt 💾                    ║
# ╚════════════════════════════════════════════════╝
PROMPT_V2_PATH.parent.mkdir(parents=True, exist_ok=True)
# Prompt files are versioned artifacts; make clobbering an existing one visible
# instead of silently replacing it.
if PROMPT_V2_PATH.exists():
    print(f"Warning: overwriting existing prompt file {PROMPT_V2_PATH}")
PROMPT_V2_PATH.write_text(new_prompt_text, encoding="utf-8")
print("Saved improved prompt to:", PROMPT_V2_PATH.resolve())
