# 🚀🔎 Literature Screening – Multi-class Model Inference    
  
This notebook calls an Azure OpenAI deployment to produce **multi-class** predictions (and any free-text fields) for the `train` and `test` splits located in `../datasets/`.    
Outputs are saved inside `all_class_files/outputs/<MODEL_NAME>/<SPLIT>/`, where `<MODEL_NAME>` is the Azure deployment name (`DEPLOYMENT_NAME`), following the structure:  
  
```  
all_class_files/  
└── outputs/  
    └── <MODEL_NAME>/  
        ├── train/  
        │   ├── predictions/  
        │   └── unparsed/  
        └── test/  
            ├── predictions/  
            └── unparsed/  
```  
  
Run the notebook once for each split (`train` then `test`) to populate both folders.  

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 1 – Imports and environment setup 🌍       ║  
# ╚════════════════════════════════════════════════╝  
import os, json, random, shutil, time  
from pathlib import Path  
from typing import Dict, Any  
  
import numpy as np  
import pandas as pd  
  
from dotenv import load_dotenv  
from openai import AzureOpenAI  
from tenacity import retry, stop_after_attempt, wait_random_exponential  
from tqdm.auto import tqdm  
  
# -- Reproducibility -------------------------------------------------------- #
# Seed both the stdlib and the NumPy global RNGs so any local randomness in
# this notebook repeats across runs. (The Azure requests are not seeded, so
# model outputs may still differ from run to run.)
SEED = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)
del _seed_fn

# -- Environment ------------------------------------------------------------ #
# Load a local .env into os.environ. Later cells read ENDPOINT_URL,
# DEPLOYMENT_NAME, AZURE_OPENAI_API_KEY and (optionally)
# AZURE_OPENAI_API_VERSION from it.
load_dotenv()

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 2 – Azure OpenAI helper 🤖                ║  
# ╚════════════════════════════════════════════════╝  
def make_client() -> AzureOpenAI:
    """Create an Azure OpenAI client configured from environment variables.

    Uses ``AZURE_OPENAI_API_KEY`` and ``ENDPOINT_URL``, plus
    ``AZURE_OPENAI_API_VERSION`` (falling back to ``"2025-01-01-preview"``).
    """
    settings = {
        "api_key": os.getenv("AZURE_OPENAI_API_KEY"),
        "azure_endpoint": os.getenv("ENDPOINT_URL"),
        "api_version": os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview"),
    }
    return AzureOpenAI(**settings)
  
# Shared client used by every inference call in this notebook.
client = make_client()

# Deployment to query. Surrounding whitespace is ignored; an unset or empty
# value aborts the notebook here.
DEPLOYMENT_NAME = (os.getenv("DEPLOYMENT_NAME") or "").strip()
if DEPLOYMENT_NAME == "":
    raise ValueError("⚠️  DEPLOYMENT_NAME is not set in the environment.")

print(f"Using deployment: {DEPLOYMENT_NAME}")

## 🔀 Choose evaluation split    
Set `EVAL_SPLIT` to either `"train"` or `"test"` before running the notebook.  

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 3 – Select evaluation split 🎛️           ║  
# ╚════════════════════════════════════════════════╝  
# Which split to run inference on: "train" or "test".
EVAL_SPLIT = "train"     # 👈 change to "test" as needed

# Fail fast on a typo instead of with a confusing "file not found" below.
VALID_SPLITS = ("train", "test")
if EVAL_SPLIT not in VALID_SPLITS:
    raise ValueError(f"EVAL_SPLIT must be one of {VALID_SPLITS}, got {EVAL_SPLIT!r}")

DATASET_DIR  = Path.cwd().parent / "datasets"
dataset_path = DATASET_DIR / f"{EVAL_SPLIT}_dataset.csv"

if not dataset_path.exists():
    raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

df_all = pd.read_csv(dataset_path)

# The inference loop reads these columns from every row; check them now
# rather than failing partway through a long run of API calls.
REQUIRED_COLUMNS = ("id", "Title", "abstract")
missing_cols = [col for col in REQUIRED_COLUMNS if col not in df_all.columns]
if missing_cols:
    raise ValueError(f"{dataset_path.name} is missing required columns: {missing_cols}")

print(f"Loaded {len(df_all):,} rows from {dataset_path.relative_to(Path.cwd().parent)}")

## ✍️  Prompt template    
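Templates live in `all_class_files/prompts/`; this notebook loads `prompts/v1.txt`.    
Each template must contain the `{TITLE}` and `{ABSTRACT}` placeholders. Any other text is allowed, but `build_prompt` fills only these two placeholders, so any other `{...}` placeholder is sent to the model verbatim.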

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 4 – Prompt template builder ✍️           ║  
# ╚════════════════════════════════════════════════╝  
# Load the prompt template once. The path is relative to the current working
# directory (presumably all_class_files/, where the templates live).
TEMPLATE_PATH = Path("prompts") / "v1.txt"
template_text = TEMPLATE_PATH.read_text(encoding="utf-8")
  
def _single_line(value) -> str:
    """Return *value* as a stripped single line; ``None``/NaN become ``""``."""
    # pandas reads empty CSV cells as NaN (a float), which has no .replace().
    if not isinstance(value, str) and pd.isna(value):
        return ""
    # splitlines() handles "\n", "\r\n" and bare "\r" (plus other Unicode
    # line breaks), turning each break into a single space.
    return " ".join(str(value).splitlines()).strip()


def build_prompt(title: str, abstract: str, *, template: "str | None" = None) -> str:
    """Fill the prompt template with a paper's title and abstract.

    Only the ``{TITLE}`` and ``{ABSTRACT}`` placeholders are substituted; all
    other template text is returned unchanged.

    Args:
        title: Paper title. Missing values (``None`` / NaN) become ``""``.
        abstract: Paper abstract, with the same missing-value handling.
        template: Template text to fill. Defaults to the module-level
            ``template_text`` loaded from ``prompts/v1.txt``.

    Returns:
        The filled prompt. Line breaks inside title/abstract are flattened to
        spaces, and surrounding whitespace is stripped from both.
    """
    if template is None:
        template = template_text
    return (
        template
        .replace("{TITLE}", _single_line(title))
        .replace("{ABSTRACT}", _single_line(abstract))
    )

## 📂 Output directories    
Results are stored inside `all_class_files/outputs/`.

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 5 – Output directories 📂                 ║  
# ╚════════════════════════════════════════════════╝  
# All artefacts for this run go under outputs/<deployment>/<split>/.
ROOT_DIR = Path.cwd()
OUTPUT_ROOT = ROOT_DIR / "outputs"
MODEL_DIR = OUTPUT_ROOT.joinpath(DEPLOYMENT_NAME, EVAL_SPLIT)

# predictions/ receives one JSON file per record; unparsed/ receives copies of
# the records whose prediction is "ParseError".
predictions_dir, unparsed_dir = (MODEL_DIR / name for name in ("predictions", "unparsed"))
for folder in (predictions_dir, unparsed_dir):
    folder.mkdir(parents=True, exist_ok=True)

print("Predictions will be saved to:", predictions_dir.resolve())

## 💬 LLM call with retry    
Auto-retries on transient errors using an exponential back-off.

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 6 – LLM call with retry 💬                ║  
# ╚════════════════════════════════════════════════╝  
from tenacity import retry, stop_after_attempt, wait_random_exponential

from openai import APIConnectionError, InternalServerError, RateLimitError
from tenacity import retry_if_exception_type

# Models that need 'max_completion_tokens' instead of 'max_tokens'
# (matched as case-insensitive substrings of DEPLOYMENT_NAME).
ALT_TOKEN_PARAM_MODELS = {"o3", "o4-mini"}

# Output-token budgets. For the reasoning models above, max_completion_tokens
# also has to cover hidden reasoning tokens, so a small cap can be used up
# before any visible answer is written (an empty reply -> ParseError). 16000
# matches the budget used by the deployment smoke test in this notebook.
DEFAULT_MAX_TOKENS = 800
REASONING_MAX_COMPLETION_TOKENS = 16000

# Errors worth retrying: connection problems/timeouts (APITimeoutError is a
# subclass of APIConnectionError), rate limiting, and 5xx server errors.
# Anything else (bad credentials, 400 bad request / content filtering) is
# deterministic, so it fails immediately instead of using all 5 attempts.
TRANSIENT_API_ERRORS = (APIConnectionError, RateLimitError, InternalServerError)


@retry(
    retry=retry_if_exception_type(TRANSIENT_API_ERRORS),
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(5),
    reraise=True,  # surface the real API error instead of tenacity.RetryError
)
def call_llm(prompt: str, *, max_output_tokens: "int | None" = None) -> str:
    """Send *prompt* to the configured deployment and return the reply text.

    Uses the module-level ``client`` and ``DEPLOYMENT_NAME``. Transient API
    errors are retried up to 5 attempts with random exponential back-off
    (1-20 s). The original exception propagates after the last attempt, or
    immediately for non-transient errors.

    Args:
        prompt: Fully rendered user prompt.
        max_output_tokens: Optional override for the output-token budget.
            Defaults to ``REASONING_MAX_COMPLETION_TOKENS`` for deployments
            matching ``ALT_TOKEN_PARAM_MODELS``, else ``DEFAULT_MAX_TOKENS``.

    Returns:
        The stripped message content, or ``""`` if the API returned no content.
    """
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an AI assistant that helps people find information.",
                }
            ],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        },
    ]
    params = {
        "model": DEPLOYMENT_NAME,
        "messages": messages,
        "temperature": 1,
        "top_p": 1,
    }

    # o3/o4-mini take max_completion_tokens; others use max_tokens (4.1 style).
    model_name = DEPLOYMENT_NAME.lower()
    is_reasoning_model = any(alt in model_name for alt in ALT_TOKEN_PARAM_MODELS)
    if max_output_tokens is None:
        max_output_tokens = (
            REASONING_MAX_COMPLETION_TOKENS if is_reasoning_model else DEFAULT_MAX_TOKENS
        )
    token_param = "max_completion_tokens" if is_reasoning_model else "max_tokens"
    params[token_param] = max_output_tokens

    resp = client.chat.completions.create(**params)
    # message.content can be None; normalise it so callers always get a str.
    return (resp.choices[0].message.content or "").strip()


## 🚀 Inference loop    
For each record we    
1. Build the prompt    
2. Call the model    
3. Parse the JSON response    
4. Save a per-instance JSON file    
  
Any response that fails JSON parsing is copied to the `unparsed/` folder for manual inspection.

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 7 – Inference loop 🚀                     ║
# ╚════════════════════════════════════════════════╝
# Prediction key -> CSV column header holding that feature's ground truth.
# Headers are looked up case-insensitively by _get_gt.
FEATURE_COLUMN_MAP = dict(
    domain="Social, Behavioural or Implementation Science?",
    dmf_stage=(
        "DMF - Identify the issue and its context, assess risks and benefits, "
        "identify and analyze options, select a strategy, implement the strategy, "
        "monitor and evaluate results, involve interested and affected parties"
    ),
    decision_type=(
        "DMF - Are the decisions regulatory, policy, or other? "
        "Please describe the “other” if applicable."
    ),
    audience="Audience",
    methodology="Methodology",
    sample_size="Sample Size",
)

# Keys every saved prediction block is padded with: "classification_rationale"
# first, then each feature key followed by its "<feature>_rationale".
EXPECTED_PRED_KEYS = ["classification_rationale"]
for _feat in FEATURE_COLUMN_MAP:
    EXPECTED_PRED_KEYS += [_feat, f"{_feat}_rationale"]
del _feat

def _clean(val):
    """Convert a cell value into a JSON-serialisable Python value.

    Returns ``None`` for missing values (None, NaN, pd.NA, NaT) so ``json.dump``
    writes ``null`` instead of the invalid token ``NaN``. NumPy scalars and
    single-element arrays are unwrapped to plain Python scalars, larger arrays
    are stringified, and anything else is returned unchanged.
    """
    # Arrays first: pd.isna on an array returns an array, and using that in an
    # ``if`` raises ValueError when it has more than one element.
    if isinstance(val, np.ndarray):
        return _clean(val.item()) if val.size == 1 else str(val)
    # Same reason: only test scalars for missingness, never list-likes.
    if not pd.api.types.is_list_like(val) and pd.isna(val):
        return None
    if isinstance(val, np.generic):
        return val.item()
    return val

def _get_gt(row: pd.Series, header: str):
    """Return the value of the column in *row* whose name matches *header*.

    Names are compared after ``strip().lower()``, so case and surrounding
    whitespace are ignored. The first matching column wins; ``None`` is
    returned when nothing matches.
    """
    wanted = header.strip().lower()
    return next(
        (row[col] for col in row.index if str(col).strip().lower() == wanted),
        None,
    )

def _strip_code_fence(text: str) -> str:
    """Return *text* stripped, unwrapping a surrounding Markdown code fence.

    A reply such as ```json\\n{...}\\n``` becomes just ``{...}`` so it can be
    parsed as JSON; unfenced text is only stripped.
    """
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    # Drop the opening fence line, which may carry a language tag like "json".
    _, _, body = stripped.partition("\n")
    body = body.rstrip()
    if body.endswith("```"):
        body = body[:-3]
    return body.strip()


def _parse_response(raw: str) -> tuple[str, Dict[str, Any]]:
    """Turn a raw model reply into ``(pred_cls, prediction_block)``.

    * JSON object: ``prediction_block`` is the parsed object; ``pred_cls`` is
      its stripped ``classification`` string, or ``"ParseError"`` when that
      field is missing, empty or not a string (the object is still kept).
    * Anything else (invalid JSON, a non-object, a ``CALL_ERROR: ...`` text):
      ``pred_cls`` is ``"ParseError"`` and the block records the error.
    """
    try:
        parsed = json.loads(_strip_code_fence(raw))
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        error = f"PARSE_ERROR: {e}"
    else:
        if isinstance(parsed, dict):
            cls = parsed.get("classification")
            pred_cls = cls.strip() if isinstance(cls, str) else ""
            return pred_cls or "ParseError", parsed
        error = f"PARSE_ERROR: expected a JSON object, got {type(parsed).__name__}"
    return "ParseError", {"classification": "ParseError", "error": error}


pred_rows: list[Dict[str, Any]] = []
unparsed_counter = 0

for _, row in tqdm(
    df_all.iterrows(),
    total=len(df_all),
    desc=f"Calling {DEPLOYMENT_NAME} ({EVAL_SPLIT})",
):
    prompt = build_prompt(row["Title"], row["abstract"])

    # Record API failures per row instead of aborting the whole run; the
    # "CALL_ERROR: ..." text is then reported as a ParseError below.
    try:
        raw_response = call_llm(prompt)
    except Exception as e:
        raw_response = f"CALL_ERROR: {e}"

    pred_cls, prediction_block = _parse_response(raw_response)
    # Count every ParseError row: these are exactly the rows whose files get
    # copied into unparsed/ (pred_cls == "ParseError").
    if pred_cls == "ParseError":
        unparsed_counter += 1

    # Give every saved prediction the same set of keys.
    for key in EXPECTED_PRED_KEYS:
        prediction_block.setdefault(key, None)

    # _clean maps NaN to None so json.dump writes null rather than invalid NaN.
    ground_truth_block = {
        "classification": _clean(row.get("label", None)),
        **{feat: _clean(_get_gt(row, hdr)) for feat, hdr in FEATURE_COLUMN_MAP.items()},
    }

    out_path = predictions_dir / f"{row['id']}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "ground_truth": ground_truth_block,
                "prediction": prediction_block,
                "raw_response": raw_response,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    pred_rows.append(
        {
            "id": row["id"],
            "gt_cls": ground_truth_block["classification"],
            "pred_cls": pred_cls,
        }
    )

print(f"✅ Finished inference – unparsed responses: {unparsed_counter}")

## 📦 Organise unparsed responses    
Copy any files whose prediction was `"ParseError"` into `unparsed/` for quick triage.

In [None]:
# ╔════════════════════════════════════════════════╗
# ║ Cell 8 – Organise unparsed 📦                  ║
# ╚════════════════════════════════════════════════╝
# Copy the per-record files whose prediction is "ParseError" into unparsed/
# for triage. Files left in unparsed/ by earlier runs are not removed.
# Explicit columns keep the filter below working even when pred_rows is empty.
df_pred = pd.DataFrame(pred_rows, columns=["id", "gt_cls", "pred_cls"])

unparsed_df = df_pred[df_pred["pred_cls"] == "ParseError"]
copied = 0
for record_id in unparsed_df["id"]:
    src = predictions_dir / f"{record_id}.json"
    if src.exists():
        shutil.copy(src, unparsed_dir / src.name)
        copied += 1

# Report files actually copied (rows whose JSON file is missing are skipped).
print(f"Unparsed files copied: {copied}")


### ✔️ Inference complete    
Rerun the notebook with the other split (e.g. `EVAL_SPLIT="test"` after a `"train"` run) to populate both output folders.  

In [None]:
import os
from dotenv import load_dotenv
from openai import AzureOpenAI

# Connection settings for the smoke test, loaded from the local .env file.
load_dotenv()
endpoint, subscription_key, api_version = (
    os.getenv("ENDPOINT_URL"),
    os.getenv("AZURE_OPENAI_API_KEY"),
    os.getenv("AZURE_OPENAI_API_VERSION", "2025-01-01-preview"),
)

# List of deployments to test (update as needed)
deployments_to_test = ["gpt-4.1", "o3"]

# Case-insensitive name substrings that get the o3-style request
# (max_completion_tokens). Matching is by substring, so "gpt-4o" already
# covers the dated gpt-4o entries.
# NOTE(review): call_llm only switches o3/o4-mini to max_completion_tokens;
# gpt-4o is treated differently here - confirm that is intended.
O3_MODELS = {"o3", "gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-05-21", "o4-mini"}
# Deployments expected to use the classic max_tokens request. Kept for
# reference: any name not matched by O3_MODELS is already treated this way.
FOUR_ONE_MODELS = {"gpt-4-1106-preview", "gpt-4.1", "gpt-4-0125-preview", "gpt-4-32k", "gpt-4-32k-0613"}


def get_model_type(deployment_name: str) -> str:
    """Classify a deployment name as ``"o3"``-style or ``"4.1"``-style.

    Returns ``"o3"`` when the lower-cased name contains any entry of
    ``O3_MODELS`` and ``"4.1"`` otherwise. A separate FOUR_ONE_MODELS check
    or an extra "o3"/"gpt-4o" test could never change this result.
    """
    name = deployment_name.lower()
    return "o3" if any(key in name for key in O3_MODELS) else "4.1"

def test_deployment(deployment):
    """Ping *deployment* with a trivial prompt and report whether it answered.

    Builds a fresh AzureOpenAI client from the module-level ``endpoint``,
    ``subscription_key`` and ``api_version``, and picks the token-limit
    parameter via ``get_model_type``. Prints the reply (or the error) and
    returns True on success, False if the request raised.
    """
    model_type = get_model_type(deployment)
    print(f"\n--- Testing deployment: {deployment} ({model_type}) ---")

    ping_client = AzureOpenAI(
        azure_endpoint=endpoint,
        api_key=subscription_key,
        api_version=api_version,
    )

    ping_text = f"Say 'pong' if you see this. What model are you? (deployment: {deployment})"
    messages = [
        {"role": role, "content": [{"type": "text", "text": text}]}
        for role, text in (
            ("system", "You are an AI assistant that helps people find information."),
            ("user", ping_text),
        )
    ]

    # o3-style deployments get max_completion_tokens with a large budget;
    # everything else uses max_tokens.
    if model_type == "o3":
        token_limit = {"max_completion_tokens": 16000}
    else:
        token_limit = {"max_tokens": 800}

    request = dict(
        model=deployment,
        messages=messages,
        temperature=1,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
        stream=False,
        **token_limit,
    )

    try:
        completion = ping_client.chat.completions.create(**request)
        print("RESPONSE:")
        print(completion.choices[0].message.content)
        print("SUCCESS\n")
    except Exception as e:
        print("ERROR:", e)
        return False
    return True

# Ping each deployment in order, then print a one-line verdict per deployment.
results = {dep: test_deployment(dep) for dep in deployments_to_test}

print("\n--- SUMMARY ---")
for dep, ok in results.items():
    status = "OK" if ok else "FAIL"
    print(f"{dep}: {status}")
