# 🚀🔎 Literature Screening – Multi-class Model Inference    
  
This notebook calls an Azure OpenAI deployment to produce **multi-class** predictions (plus any extra free-text fields the model returns) for the `train` and `test` splits located in `../datasets/`.    
Outputs are saved inside `all_class_files/outputs/<MODEL_NAME>/<SPLIT>/` following the structure:  
  
```  
all_class_files/  
└── outputs/  
    └── <MODEL_NAME>/  
        ├── train/  
        │   ├── predictions/  
        │   └── unparsed/  
        └── test/  
            ├── predictions/  
            └── unparsed/  
```  
  
Run the notebook once for each split (`train` then `test`) to populate both folders.  

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 1 – Imports and environment setup 🌍       ║  
# ╚════════════════════════════════════════════════╝  
import os, json, random, shutil, time  
from pathlib import Path  
from typing import Dict, Any  
  
import numpy as np  
import pandas as pd  
  
from dotenv import load_dotenv  
from openai import AzureOpenAI  
from tenacity import retry, stop_after_attempt, wait_random_exponential  
from tqdm.auto import tqdm  
  
# -- Reproducibility -------------------------------------------------------- #
# Seed both the stdlib and NumPy global RNGs with the same value.
# NOTE: this does not make the LLM replies themselves deterministic.
SEED = 42
for _seed_rng in (random.seed, np.random.seed):
    _seed_rng(SEED)
del _seed_rng

# -- Environment ------------------------------------------------------------ #
# Load variables from a .env file into the process environment. Later cells read
# ENDPOINT_URL, DEPLOYMENT_NAME, AZURE_OPENAI_API_KEY and, optionally,
# AZURE_OPENAI_API_VERSION.
load_dotenv()

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 2 – Azure OpenAI helper 🤖                ║  
# ╚════════════════════════════════════════════════╝  
def make_client() -> AzureOpenAI:
    """Create an Azure OpenAI client configured from environment variables.

    Uses AZURE_OPENAI_API_KEY for the key and ENDPOINT_URL for the endpoint.
    Uses AZURE_OPENAI_API_VERSION for the API version, falling back to
    "2025-01-01-preview" when that variable is unset.
    """
    env = os.environ
    settings = {
        "api_key": env.get("AZURE_OPENAI_API_KEY"),
        "azure_endpoint": env.get("ENDPOINT_URL"),
        "api_version": env.get("AZURE_OPENAI_API_VERSION", "2025-01-01-preview"),
    }
    return AzureOpenAI(**settings)
  
client = make_client()

# An unset variable and an empty/whitespace-only value are treated the same.
DEPLOYMENT_NAME = (os.getenv("DEPLOYMENT_NAME") or "").strip()

# Fail fast: every request and every output path depends on the deployment name.
if DEPLOYMENT_NAME == "":
    raise ValueError("⚠️  DEPLOYMENT_NAME is not set in the environment.")
print(f"Using deployment: {DEPLOYMENT_NAME}")

## 🔀 Choose evaluation split    
Set `EVAL_SPLIT` to either `"train"` or `"test"` before running the notebook.  

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 3 – Select evaluation split 🎛️           ║  
# ╚════════════════════════════════════════════════╝  
VALID_SPLITS = ("train", "test")
EVAL_SPLIT   = "train"     # 👈 change to "test" as needed

if EVAL_SPLIT not in VALID_SPLITS:
    raise ValueError(f"EVAL_SPLIT must be one of {VALID_SPLITS}, got {EVAL_SPLIT!r}")

# Datasets live in a `datasets/` folder one level above the working directory.
DATASET_DIR  = Path.cwd().parent / "datasets"
dataset_path = DATASET_DIR / f"{EVAL_SPLIT}_dataset.csv"

if not dataset_path.exists():
    raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

df_all = pd.read_csv(dataset_path)

# The inference loop reads these columns for every row ("label" is optional).
# Checking here fails before any API calls are made.
REQUIRED_COLUMNS = {"id", "Title", "abstract"}
missing_columns  = REQUIRED_COLUMNS - set(df_all.columns)
if missing_columns:
    raise KeyError(f"Dataset is missing required column(s): {sorted(missing_columns)}")

print(f"Loaded {len(df_all):,} rows from {dataset_path.relative_to(Path.cwd().parent)}")

## ✍️  Prompt template    
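Templates live in `all_class_files/prompts/` (this notebook loads `prompts/v0.txt`).    
Each template should contain the `{TITLE}` and `{ABSTRACT}` placeholders. These are the only placeholders that get filled in; any other `{...}` text is sent to the model verbatim. The template should ask the model to reply with a JSON object containing `classification` and `rationale` keys. Any additional keys are saved under `extras`.  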

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 4 – Prompt template builder ✍️           ║  
# ╚════════════════════════════════════════════════╝  
# Load the prompt template (path is relative to the current working directory).
with open(Path("prompts") / "v0.txt", encoding="utf-8") as _template_file:
    template_text = _template_file.read()
  
def build_prompt(title: str, abstract: str, *, template: "str | None" = None) -> str:
    """Fill the prompt template with a paper's title and abstract.

    All ``{TITLE}`` and ``{ABSTRACT}`` placeholders are replaced in a single pass.
    As a result, placeholder-like text inside the title or abstract is never
    substituted a second time.

    Missing values (None, or NaN as pandas yields for empty CSV cells) become
    empty strings instead of crashing the run. Line breaks (``\\n``, ``\\r\\n``,
    ``\\r``) are flattened to spaces.

    Args:
        title: Paper title.
        abstract: Paper abstract.
        template: Template text to fill. Defaults to the module-level
            ``template_text``.

    Returns:
        The filled prompt string.
    """
    import re

    if template is None:
        template = template_text

    def _clean(value) -> str:
        # pandas represents empty CSV cells as NaN (a float); treat as "".
        if not isinstance(value, str):
            if value is None or pd.isna(value):
                return ""
            value = str(value)
        return value.replace("\r\n", " ").replace("\r", " ").replace("\n", " ").strip()

    fields = {"TITLE": _clean(title), "ABSTRACT": _clean(abstract)}
    # A function replacement inserts the text literally (no backslash escapes).
    return re.sub(r"\{(TITLE|ABSTRACT)\}", lambda m: fields[m.group(1)], template)

## 📂 Output directories    
Results are stored inside `all_class_files/outputs/`.

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 5 – Output directories 📂                 ║  
# ╚════════════════════════════════════════════════╝  
ROOT_DIR    = Path.cwd()
OUTPUT_ROOT = ROOT_DIR / "outputs"
# Layout: outputs/<deployment name>/<split>/{predictions,unparsed}
MODEL_DIR   = OUTPUT_ROOT / DEPLOYMENT_NAME / EVAL_SPLIT

predictions_dir, unparsed_dir = (MODEL_DIR / sub for sub in ("predictions", "unparsed"))

# Create both folders up front; re-running the cell is harmless (exist_ok).
for folder in (predictions_dir, unparsed_dir):
    folder.mkdir(parents=True, exist_ok=True)

print("Predictions will be saved to:", predictions_dir.resolve())

## 💬 LLM call with retry    
Transient API failures are retried with randomized exponential back-off (1–20 s between attempts, at most 5 attempts in total).

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 6 – LLM call with retry 💬                ║  
# ╚════════════════════════════════════════════════╝  
from openai import APIConnectionError, InternalServerError, RateLimitError
from tenacity import retry_if_exception_type

# Only these are worth retrying: network failures/timeouts (APITimeoutError is a
# subclass of APIConnectionError), 429 rate limits and 5xx server errors.
# Anything else (auth, bad request, content filter, ...) fails fast instead of
# burning 5 attempts with back-off on every record.
_TRANSIENT_ERRORS = (APIConnectionError, RateLimitError, InternalServerError)


@retry(
    retry=retry_if_exception_type(_TRANSIENT_ERRORS),
    wait=wait_random_exponential(min=1, max=20),
    stop=stop_after_attempt(5),
    reraise=True,  # surface the real openai error, not tenacity.RetryError
)
def call_llm(prompt: str, *, max_tokens: int = 800) -> str:
    """Send *prompt* to the configured deployment and return the reply text.

    Tries ``max_completion_tokens`` first. If the deployment rejects it, falls
    back to the older ``max_tokens`` parameter.

    Args:
        prompt: Fully rendered user prompt.
        max_tokens: Completion-token budget (default 800).

    Returns:
        The stripped message content, or "" when the service returned no
        content (the SDK types ``message.content`` as optional).

    Raises:
        The underlying ``openai`` exception. Transient errors are raised once
        retries are exhausted; any other error is raised immediately.
    """
    messages = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an AI assistant that helps people find information."
                }
            ],
        },
        {
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        },
    ]
    sampling = {"temperature": 1, "top_p": 1}

    try:
        resp = client.chat.completions.create(
            model                 = DEPLOYMENT_NAME,
            messages              = messages,
            max_completion_tokens = max_tokens,
            **sampling,
        )
    except Exception as e:
        # Fall back for deployments that only accept the older parameter name.
        message = str(e)
        if "max_completion_tokens" in message and "unsupported_parameter" in message:
            resp = client.chat.completions.create(
                model      = DEPLOYMENT_NAME,
                messages   = messages,
                max_tokens = max_tokens,
                **sampling,
            )
        else:
            raise

    content = resp.choices[0].message.content
    return (content or "").strip()

## 🚀 Inference loop    
For each record we    
1. Build the prompt    
2. Call the model    
3. Parse the JSON response    
4. Save a per-instance JSON file    
  
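Any response that fails JSON parsing, including failed API calls recorded as `CALL_ERROR: …`, gets the prediction `"ParseError"`. The next cell copies those files to the `unparsed/` folder for manual inspection.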

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 7 – Inference loop 🚀                     ║  
# ╚════════════════════════════════════════════════╝  
import re

# Models often wrap JSON replies in a markdown fence (```json ... ```).
# Unwrap it so such replies are not misreported as parse errors.
_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.DOTALL | re.IGNORECASE)


def _as_text(value: Any) -> str:
    """Return *value* as a stripped string, mapping None (JSON null) to ""."""
    return "" if value is None else str(value).strip()


def _parse_response(raw: str) -> tuple[str, str, Dict[str, Any]]:
    """Split a model reply into (classification, rationale, extras).

    Raises ValueError (json.JSONDecodeError is a subclass) when the reply is not
    a JSON object or its "classification" value is empty.
    """
    fenced = _JSON_FENCE_RE.match(raw)
    parsed = json.loads(fenced.group(1) if fenced else raw)
    if not isinstance(parsed, dict):
        raise ValueError(f"Expected a JSON object, got {type(parsed).__name__}")

    label = _as_text(parsed.get("classification"))
    if not label:
        raise ValueError("Empty classification value")

    # Capture any other fields that may be returned (free-text, scores, etc.).
    extras = {k: v for k, v in parsed.items() if k not in ("classification", "rationale")}
    return label, _as_text(parsed.get("rationale")), extras


pred_rows: list[Dict[str, Any]] = []
unparsed_counter = 0

for _, row in tqdm(df_all.iterrows(), total=len(df_all), desc=f"Calling {DEPLOYMENT_NAME} ({EVAL_SPLIT})"):
    prompt       = build_prompt(row["Title"], row["abstract"])
    ground_truth = row.get("label", None)   # optional column

    # ---------- 1) Call the model ---------- #
    # A failed call is recorded (it then fails parsing below) rather than
    # aborting the whole run.
    try:
        raw_output = call_llm(prompt)
    except Exception as e:
        raw_output = f"CALL_ERROR: {e}"

    # ---------- 2) Parse JSON --------------- #
    try:
        pred_label, rationale, extras = _parse_response(raw_output)
    except (ValueError, TypeError) as e:
        pred_label        = "ParseError"
        rationale         = f"PARSE_ERROR: {e}"
        extras            = {}
        unparsed_counter += 1

    # ---------- 3) Save individual JSON ----- #
    out_path = predictions_dir / f"{row['id']}.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(
            {
                "ground_truth": ground_truth,
                "prediction"  : pred_label,
                "rationale"   : rationale,
                "extras"      : extras,
                "raw_response": raw_output,
            },
            f,
            ensure_ascii=False,
            indent=2,
        )

    # ---------- 4) Collect for summary ------ #
    pred_rows.append(
        {
            "id"          : row["id"],
            "ground_truth": ground_truth,
            "prediction"  : pred_label,
            "rationale"   : rationale,
        }
    )

print(f"✅ Finished inference. Unparsed responses: {unparsed_counter}")

## 📦 Organise unparsed responses    
Copy any files whose prediction was `"ParseError"` into `unparsed/` for quick triage.

In [None]:
# ╔════════════════════════════════════════════════╗  
# ║ Cell 8 – Organise unparsed 📦                  ║  
# ╚════════════════════════════════════════════════╝  
# Explicit columns keep the frame well-formed even when pred_rows is empty.
# Without them, df_pred["prediction"] would raise KeyError.
df_pred = pd.DataFrame(pred_rows, columns=["id", "ground_truth", "prediction", "rationale"])

unparsed_df = df_pred[df_pred["prediction"] == "ParseError"]

n_copied = 0
for record_id in unparsed_df["id"]:
    src = predictions_dir / f"{record_id}.json"
    if src.exists():
        shutil.copy(src, unparsed_dir / src.name)
        n_copied += 1

# Report files actually copied, not just rows flagged as ParseError.
print(f"Unparsed files copied: {n_copied}")

### ✔️ Inference complete    
Rerun the notebook with the other `EVAL_SPLIT` value (`"train"` or `"test"`) to populate both splits.  