# BC4SDGS (PSTW) USECASE EVALUATOR

Evaluate the BC cases over the 17 different SDGs.




In [6]:
import os, sys, importlib.util as u, importlib

# Vendored third-party packages (LIME) live outside the active environment.
VENDOR = os.path.expanduser("~/BC4SDGs/_vendor")
# Append rather than prepend so the environment's own numpy/scipy still take precedence.
if VENDOR not in sys.path:
    sys.path.append(VENDOR)
importlib.invalidate_caches()  # make the newly added path visible to the import system

print(f"vendor: {VENDOR} exists: {os.path.isdir(VENDOR)}")
print(f"lime spec: {u.find_spec('lime') is not None}")

from lime.lime_text import LimeTextExplainer  # raises ImportError if vendoring did not work
print("LIME import OK")



vendor: /delta/JRC/services/jupyter-home/<user>/BC4SDGs/_vendor exists: True
lime spec: True
LIME import OK


## Set globals

In [7]:
import getpass
import os

RELMMYY = "202507"

## TOKEN REQUIRED
# Never hard-code the token: a literal here would overwrite any token already
# exported in the environment and would be saved with the notebook. Use the
# environment value when present; otherwise prompt for it (input is not echoed).
GPT_JRC_TOKEN = os.environ.get("GPT_JRC_TOKEN") or getpass.getpass("GPT_JRC_TOKEN: ")

# Export it so later cells (and the OpenAI client) can read it from the environment.
os.environ["GPT_JRC_TOKEN"] = GPT_JRC_TOKEN



## GET LLM List from API

In [8]:
import requests
import sqlite3
import pandas as pd
import logging
from openai import OpenAI
import os

# Single definition of the endpoint and credentials. The token is read from the
# environment variable exported by the "Set globals" cell.
api_key = os.environ["GPT_JRC_TOKEN"]
base_url = "https://api-gpt.jrc.ec.europa.eu/v1"  # JRC OpenAI-compatible gateway

client = OpenAI(api_key=api_key, base_url=base_url)


def list_available_models(timeout: float = 30.0):
    """Print and return the model IDs exposed by the gateway's ``/models`` endpoint.

    Args:
        timeout: seconds to wait for the HTTP response (requests has no default
            timeout, so without one a stalled gateway would hang the cell).

    Returns:
        List of model ID strings. The list is empty if the request fails, the body
        is not valid JSON, or the response has no ``data`` entries.
    """
    try:
        response = requests.get(
            f"{base_url}/models",
            headers={"Authorization": f"Bearer {api_key}"},
            timeout=timeout,
        )
        response.raise_for_status()
        payload = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers JSON decoding failures on older requests versions.
        print(f"An error occurred: {e}")
        return []

    entries = payload.get("data", []) if isinstance(payload, dict) else []
    model_ids = [m["id"] for m in entries if isinstance(m, dict) and "id" in m]
    print("Available Models:")
    for model_id in model_ids:
        print(model_id)
    return model_ids


# Assign the result so the returned list is not echoed as the cell output.
available_models = list_available_models()


Available Models:
multilingual-e5-large
mistral-small-3.1-24b
react-agent-mistral-3.2
gpt-4o
moderation_multilingual
qwen-coder-2.5-base
mistral-small-3.2-24b
jina-embeddings-v2-base-en
qwen3-32b
qwen-coder-2.5-instruct
llama-3.3-70b-instruct
e5-large-v2


# USECASE BC4SDGS EVALUATOR

This code must:

*  Load the first batch of SDG descriptions (at most SDGS_BATCH rows) from SDGS_PATH and set up the system prompt with the extracted list of SDGs, using the "SDG-ID" (1 to 17) available in the file.
  
* Set up the required columns in the pstw_data table for each SDG (17 in total), creating, if not present, two new columns per SDG: the first named "SDG_"+[SDG-ID]+"_vote" and the second named "SDG_"+[SDG-ID]+"_comment".
  
* Loop over each row in the pstw_data table, evaluating it using its "Name" and "Description" fields.
import sqlite3

MODEL_NAME = "llama-3.3-70b-instruct"
# The token is intentionally not echoed here: a bare `GPT_JRC_TOKEN` expression
# would render the secret in the cell output and save it with the notebook.


In [9]:
import os
import time
import logging
import sqlite3
from typing import List, Optional
import pandas as pd
from openai import OpenAI, RateLimitError
from datetime import datetime

# --------------------------------------------------------------------------- #
# Configuration
# --------------------------------------------------------------------------- #
RELMMYY = "202507"
DATA_DIR = "./data"
SDGS_PATH = os.path.join(DATA_DIR, "SDGs.csv")
DB_PATH = os.path.join(DATA_DIR, f"pstw_local_database_{RELMMYY}.db")

DEBUG = False
OVERWRITE_SDG = True          # process all rows if True; else only rows with SDG_1_vote IS NULL
BATCH_SIZE = 50               # rows fetched and committed per batch
NUM_ROWS_TO_PROCESS: Optional[int] = None   # None = all eligible rows
TEMPERATURE = 0.1
# Alternative model: "llama-3.3-70b-instruct"
MODEL_NAME = "gpt-4o"
SDGS_TOTAL = 17
SDGS_BATCH = 17               # number of SDG reference rows read from SDGS_PATH
LOG_FILE = "debug_logfile.log"

# --------------------------------------------------------------------------- #
# Logging
# --------------------------------------------------------------------------- #
logging.basicConfig(
    filename=LOG_FILE,
    level=logging.DEBUG if DEBUG else logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
# OpenAI client
# --------------------------------------------------------------------------- #
# Use .get() so a missing variable reaches the descriptive error below instead
# of failing first with a bare KeyError.
api_key = os.environ.get("GPT_JRC_TOKEN")
if not api_key:
    raise EnvironmentError("Set the GPT_JRC_TOKEN environment variable.")
client = OpenAI(api_key=api_key, base_url="https://api-gpt.jrc.ec.europa.eu/v1")

# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
def load_sdgs(path: str, batch_size: int) -> pd.DataFrame:
    """Read the SDG reference CSV and keep only its first ``batch_size`` rows."""
    sdg_frame = pd.read_csv(path)
    return sdg_frame.head(batch_size)

def ensure_sdg_columns(cur: sqlite3.Cursor, total: Optional[int] = None) -> None:
    """Add any missing SDG_<i>_vote / SDG_<i>_comment columns, plus sdg_overall, to pstw_data.

    Existing columns are detected with ``PRAGMA table_info`` rather than by
    catching ALTER TABLE failures. Genuine errors (locked database, etc.) are
    therefore no longer swallowed, as the old ``sdg_overall`` branch did. The
    function is idempotent.

    Args:
        cur: cursor on the PSTW SQLite database.
        total: number of SDGs; defaults to the module-level SDGS_TOTAL.

    Raises:
        sqlite3.OperationalError: pstw_data does not exist or cannot be altered.
    """
    n_sdgs = SDGS_TOTAL if total is None else total
    # SQLite column names are case-insensitive, so compare lower-cased.
    existing = {row[1].lower() for row in cur.execute("PRAGMA table_info(pstw_data)")}
    if not existing:
        # PRAGMA table_info returns no rows for a missing table.
        raise sqlite3.OperationalError("no such table: pstw_data")

    wanted = [
        (f"SDG_{i}_{suffix}", col_type)
        for i in range(1, n_sdgs + 1)
        for suffix, col_type in (("vote", "INTEGER"), ("comment", "TEXT"))
    ]
    wanted.append(("sdg_overall", "INTEGER"))

    for col, col_type in wanted:
        if col.lower() not in existing:
            cur.execute(f"ALTER TABLE pstw_data ADD COLUMN {col} {col_type};")
            existing.add(col.lower())

def parse_scores(raw: str, expected: Optional[int] = None) -> Optional[List[int]]:
    """Parse an LLM reply into exactly ``expected`` integer scores in 1..10.

    Tokens are whitespace-separated. Commas, semicolons and square brackets
    around a token ("5," or "[5") are stripped, so list-style replies still parse.
    ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` accepts
    characters such as "²" that ``int()`` rejects with ValueError.

    Args:
        raw: model output text (None/empty is treated as no scores).
        expected: required number of scores; defaults to SDGS_TOTAL.

    Returns:
        The list of scores, or None if the count is wrong or any value is outside 1..10.
    """
    n_expected = SDGS_TOTAL if expected is None else expected
    nums: List[int] = []
    for tok in (raw or "").split():
        tok = tok.strip(",;[]")
        if tok.isdecimal():
            nums.append(int(tok))
    return nums if len(nums) == n_expected and all(1 <= n <= 10 for n in nums) else None

def call_llm(messages, max_tokens: int, temperature: Optional[float] = None) -> str:
    """Send a chat-completion request and return the stripped reply text.

    Makes up to 4 attempts, with exponential back-off (1s, 2s, 4s) between them,
    on any error including rate limits. There is no sleep after the final
    failure. Returns "" when every attempt fails.

    Args:
        messages: OpenAI-style chat messages.
        max_tokens: completion token budget.
        temperature: sampling temperature; None means the global TEMPERATURE.
    """
    attempts = 4
    delay = 1.0
    for attempt in range(1, attempts + 1):
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=max_tokens,
                temperature=TEMPERATURE if temperature is None else temperature,
            )
            # content can be None; treat it as empty text rather than raising
            # AttributeError and wasting a retry on a deterministic failure.
            return (resp.choices[0].message.content or "").strip()
        except Exception as exc:  # RateLimitError is an Exception subclass
            if attempt == attempts:
                logger.error("LLM error (attempt %d/%d): %s – giving up", attempt, attempts, exc)
                break
            logger.warning("LLM error: %s – retrying in %.1fs", exc, delay)
            time.sleep(delay)
            delay *= 2
    return ""

def evaluate_case(description: str, system_prompt: str) -> Optional[List[int]]:
    """Ask the model for the 17 SDG scores of one initiative.

    Returns None for an empty description or when the reply does not parse into
    17 integers in 1..10.
    """
    if description:
        instruction = "Provide 17 integers (1–10) separated by spaces — one per SDG, no extra text."
        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"{description}\n\n{instruction}"},
        ]
        return parse_scores(call_llm(conversation, max_tokens=50))
    return None

def generate_sdg_comments(description: str, scores: List[int]) -> List[str]:
    """Ask the model for a short justification of each SDG score.

    Args:
        description: initiative text that was scored.
        scores: the previously assigned scores, SDG 1 first.

    Returns:
        SDGS_TOTAL comment strings, each at most 500 characters. A comment is ""
        when the reply contained no parsable line for that SDG.
    """
    import re  # local import: this cell does not import re at module level

    score_vec = " ".join(map(str, scores))
    prompt = (
        "You are a senior sustainability analyst. For the initiative below you have assigned a score for each SDG. "
        "For each SDG, write a short explanation (≤500 characters) justifying the numeric score. "
        "Avoid generic SDG definitions; reference concrete aspects of the initiative. "
        "Output exactly 17 lines, each formatted as `N: (score) comment`.\n\n"
        f"Initiative description:\n{description}\n\nScores: {score_vec}"
    )
    raw = call_llm([{"role": "user", "content": prompt}], max_tokens=1500)
    comments = ["" for _ in range(SDGS_TOTAL)]
    # Accept the requested "N: ..." form and common variants such as "SDG N: ...",
    # "**N**: ..." or "- N: ...". Unlike str.isdigit(), \d does not match
    # superscripts such as "¹", which int() would reject.
    line_re = re.compile(r"^[\s*#>\-]*(?:SDG\s*)?(\d+)\**\s*:\s*(.*)$", re.IGNORECASE)
    for line in raw.splitlines():
        m = line_re.match(line)
        if not m:
            continue
        i = int(m.group(1))
        if 1 <= i <= SDGS_TOTAL:
            comments[i - 1] = m.group(2).strip()[:500]
    return comments

def process_cases(system_prompt: str) -> None:
    """Score pstw_data rows on all SDGs and write votes, comments and sdg_overall back.

    Rows are visited once each, in ascending ``id`` order, using a keyset cursor
    (``id > last_id``) in both modes. When OVERWRITE_SDG is False only rows with
    ``SDG_1_vote IS NULL`` are eligible. A row whose scoring fails keeps its NULL
    vote, and the moving cursor skips past it. Previously, failing rows were
    re-fetched at the head of every batch, so later rows were never reached.

    At most NUM_ROWS_TO_PROCESS rows are visited when that setting is truthy. The
    function commits after every batch and always closes the connection.
    """
    eligible = "1" if OVERWRITE_SDG else "SDG_1_vote IS NULL"
    select_sql = f"SELECT id, name, description FROM pstw_data WHERE {eligible}"
    vote_set = ", ".join(f"SDG_{i}_vote = ?" for i in range(1, SDGS_TOTAL + 1))
    cmt_set = ", ".join(f"SDG_{i}_comment = ?" for i in range(1, SDGS_TOTAL + 1))
    update_sql = f"UPDATE pstw_data SET {vote_set}, {cmt_set}, sdg_overall = ? WHERE id = ?"

    # sqlite3's connection context manager only commits/rolls back; it never closes.
    conn = sqlite3.connect(DB_PATH)
    try:
        cur = conn.cursor()
        ensure_sdg_columns(cur)
        conn.commit()

        total = cur.execute(f"SELECT COUNT(*) FROM pstw_data WHERE {eligible}").fetchone()[0]
        if NUM_ROWS_TO_PROCESS:
            total = min(total, NUM_ROWS_TO_PROCESS)
        print(f"Rows to process (SDG scoring): {total}")
        logger.info("Rows to process (SDG scoring): %d", total)

        processed = 0
        last_id = None  # keyset cursor: id of the last row visited
        while processed < total:
            limit = min(BATCH_SIZE, total - processed)
            if last_id is None:
                rows = cur.execute(f"{select_sql} ORDER BY id LIMIT ?", (limit,)).fetchall()
            else:
                rows = cur.execute(
                    f"{select_sql} AND id > ? ORDER BY id LIMIT ?", (last_id, limit)
                ).fetchall()
            if not rows:
                break

            for row_id, name, desc in rows:
                text = desc or name
                scores = evaluate_case(text, system_prompt)
                if not scores:
                    logger.warning("Row %s: no valid SDG scores; left unscored", row_id)
                    continue
                comments = generate_sdg_comments(text, scores)
                overall = round(sum(scores) / SDGS_TOTAL)
                cur.execute(update_sql, [*scores, *comments, overall, row_id])

            conn.commit()
            processed += len(rows)
            last_id = rows[-1][0]
            print(f"{processed}/{total} done")
            logger.info("Processed %d/%d", processed, total)
    finally:
        conn.close()

    print("SDG scoring complete.")
    logger.info("SDG scoring complete.")

def main() -> None:
    """Build the SDG system prompt from the reference CSV and score every eligible row.

    Each keyword set in the prompt is labelled with its SDG number: the CSV's
    "SDG-ID" column when present, otherwise the 1-based row position. This lets
    the model map each reference to the matching slot in its 17-score reply.
    """
    start_dt = datetime.now()
    print(f"=== SDG scoring started: {start_dt.isoformat(sep=' ', timespec='seconds')} ===")
    logger.info("Process started at %s", start_dt.isoformat())

    sdg_df = load_sdgs(SDGS_PATH, SDGS_BATCH)
    keywords = sdg_df["Keywords_Merged"].fillna("").astype(str).tolist()
    if "SDG-ID" in sdg_df.columns:
        sdg_ids = sdg_df["SDG-ID"].astype(str).tolist()
    else:
        sdg_ids = [str(pos) for pos in range(1, len(keywords) + 1)]
    sdg_ref = "\n".join(f"SDG {sid}: {kw}" for sid, kw in zip(sdg_ids, keywords))
    sys_prompt = (
        "You are a senior sustainability analyst. Score each initiative from 1 to 10 for all 17 SDGs. "
        "Use these SDG keyword sets as reference:\n" + sdg_ref
    )

    process_cases(sys_prompt)

    end_dt = datetime.now()
    elapsed = end_dt - start_dt
    print(f"=== Finished: {end_dt.isoformat(sep=' ', timespec='seconds')} (elapsed {elapsed}) ===")
    logger.info("Process finished at %s (elapsed %s)", end_dt.isoformat(), str(elapsed))

if __name__ == "__main__":
    main()



=== SDG scoring started: 2025-09-08 19:39:31 ===
Rows to process (SDG scoring): 306
50/306 done
100/306 done
150/306 done
200/306 done
250/306 done
300/306 done
306/306 done
SDG scoring complete.
=== Finished: 2025-09-08 20:10:26 (elapsed 0:30:54.954109) ===


## SDG Evaluator Results Comparator




In [5]:
# Merge BC4SDGs evaluator runs into one CSV with refined ordering
# - Join key: pstw_id (PSTW identifier)
# - Left-join on master
# - Inputs in ./eval_results/
# - Output: ./eval_results/bc4sdgs_202507_merged.csv
# - SDG_k_vote_rev stays next to votes
# - "Notes about observing the votes assigned" and
#   "Potential SDGs not considered by the Automated Process"
#   are placed at the end after SDG_Overall_Votes_Summary

from pathlib import Path
import re
import pandas as pd

FOLDER = Path("./eval_results")
MASTER_NAME = "bc4sdgs_descr_202507_master.csv"
KEYWORDS_PREFIX = "bc4sdgs_keywords_"
OUTPUT_NAME = "bc4sdgs_202507_merged.csv"
JOIN_KEY = "pstw_id"

# Trailing columns: emitted last, in this order, when present.
SUMMARY_COL, NOTES_COL, POTENTIAL_COL = (
    "SDG_Overall_Votes_Summary",
    "Notes about observing the votes assigned",
    "Potential SDGs not considered by the Automated Process",
)

# SDG column families for SDG 1..17.
SDG_IDS = [k for k in range(1, 18)]
VOTE_COLS = ["SDG_%d_vote" % k for k in SDG_IDS]
VOTE_REV_COLS = ["SDG_%d_vote_rev" % k for k in SDG_IDS]  # human-in-the-loop revised votes
COMM_COLS = ["SDG_%d_comment" % k for k in SDG_IDS]
SDG_COLS = [*VOTE_COLS, *VOTE_REV_COLS, *COMM_COLS]

def read_csv_any(path: Path) -> pd.DataFrame:
    """Read a CSV with all columns as strings, trying common encodings in turn.

    Only decoding failures move on to the next encoding. Any other error (missing
    file, malformed CSV, ...) is raised immediately instead of being swallowed.
    "utf-8-sig" is tried first because it also decodes BOM-less UTF-8 and strips
    a leading BOM from the first header. "latin-1" maps every byte, so it is the
    final fallback.

    Raises:
        OSError / pandas parser errors: propagated unchanged.
        UnicodeDecodeError: only if every encoding fails (not expected, given latin-1).
    """
    last_err = None
    for enc in ("utf-8-sig", "cp1252", "latin-1"):
        try:
            return pd.read_csv(path, dtype=str, encoding=enc)
        except UnicodeDecodeError as exc:
            last_err = exc
    raise last_err

def normalize_columns(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with headers stringified, trimmed and free of a leading BOM."""
    def _clean(name) -> str:
        # Trim whitespace first, then drop a UTF-8 BOM that survived decoding.
        return str(name).strip().lstrip("\ufeff")

    renamed = df.copy()
    renamed.columns = [_clean(name) for name in renamed.columns]
    return renamed

def postfix_from_filename(fname: str, prefix=None) -> "str | None":
    """Derive the column suffix for an evaluator file named ``<prefix><core>.csv``.

    Runs of non-word characters in ``core`` collapse to "_", and leading or
    trailing "_" are trimmed.

    Args:
        fname: file name or path; only the final component is inspected.
        prefix: required file-name prefix; defaults to KEYWORDS_PREFIX.

    Returns:
        The sanitized suffix, or None when the name does not match
        ``<prefix>*.csv`` or the sanitized core is empty. An empty suffix would
        produce ambiguous columns such as "SDG_1_vote__".
    """
    prefix = KEYWORDS_PREFIX if prefix is None else prefix
    base = Path(fname).name
    if not base.startswith(prefix) or not base.endswith(".csv"):
        return None
    core = base[len(prefix):-len(".csv")]
    core = re.sub(r"\W+", "_", core).strip("_")
    return core or None

def ensure_master_unique(df_master: pd.DataFrame):
    """Validate that the master table has the join key and that its values are unique.

    Raises:
        KeyError: the JOIN_KEY column is absent.
        ValueError: at least one JOIN_KEY value occurs more than once.
    """
    missing_key = JOIN_KEY not in df_master.columns
    if missing_key:
        raise KeyError(f"Master missing join key '{JOIN_KEY}'.")
    dup_mask = df_master[JOIN_KEY].duplicated(keep=False)
    if dup_mask.any():
        raise ValueError(f"Master '{JOIN_KEY}' contains duplicates. Uniqueness required.")

def load_master(master_path: Path) -> pd.DataFrame:
    """Load the master CSV, normalize its headers, and validate the join key.

    The key's presence is checked before the column is accessed. A missing key
    therefore yields the descriptive KeyError, not a bare pandas KeyError.

    Raises:
        KeyError: the JOIN_KEY column is absent.
        ValueError: JOIN_KEY values are not unique (via ensure_master_unique).
    """
    dfm = normalize_columns(read_csv_any(master_path))
    if JOIN_KEY not in dfm.columns:
        raise KeyError(f"Master missing join key '{JOIN_KEY}'.")
    dfm[JOIN_KEY] = dfm[JOIN_KEY].astype(str)
    ensure_master_unique(dfm)
    return dfm

def select_and_suffix_eval(df_eval: pd.DataFrame, postfix: str):
    """Keep the join key plus any SDG vote/comment columns, one row per key.

    SDG columns are renamed to ``<col>__<postfix>``.

    Returns:
        (slim_df, have_cols): the renamed frame, and the original names of the SDG
        columns that were found.
    """
    available = set(df_eval.columns)
    have_cols = [col for col in VOTE_COLS + COMM_COLS if col in available]
    deduped = df_eval[[JOIN_KEY, *have_cols]].drop_duplicates(subset=[JOIN_KEY], keep="first")
    slim = deduped.rename(columns={col: f"{col}__{postfix}" for col in have_cols})
    return slim, have_cols

def build_column_order(master_cols, all_cols):
    """Compute the column order of the merged table.

    The order is:
    1. JOIN_KEY.
    2. Non-SDG master columns.
    3. For each SDG k: SDG_k_vote, SDG_k_vote_rev, the per-file SDG_k_vote__*
       columns (sorted), SDG_k_comment, then the per-file SDG_k_comment__* columns
       (sorted).
    4. Any remaining columns.
    5. The summary / notes / potential-SDG columns that are present, in that order.
    """
    master_cols = list(master_cols)
    all_cols = list(all_cols)
    present = set(all_cols)

    trailing = [col for col in (SUMMARY_COL, NOTES_COL, POTENTIAL_COL) if col in present]

    head = [JOIN_KEY]
    for col in master_cols:
        if col == JOIN_KEY or col.startswith("SDG_") or col in trailing:
            continue
        head.append(col)

    sdg_block = []
    for k in SDG_IDS:
        vote, rev, comm = f"SDG_{k}_vote", f"SDG_{k}_vote_rev", f"SDG_{k}_comment"
        sdg_block.extend(col for col in (vote, rev) if col in present)
        sdg_block.extend(sorted(col for col in all_cols if col.startswith(vote + "__")))
        if comm in present:
            sdg_block.append(comm)
        sdg_block.extend(sorted(col for col in all_cols if col.startswith(comm + "__")))

    seen = set(head) | set(sdg_block) | set(trailing)
    leftover = [col for col in all_cols if col not in seen]
    return head + sdg_block + leftover + trailing

def main():
    """Left-join every evaluator CSV in FOLDER onto the master table and write the result.

    Raises:
        FileNotFoundError: FOLDER or the master CSV does not exist.
    """
    if not FOLDER.exists():
        raise FileNotFoundError(f"Folder not found: {FOLDER}")
    master_path = FOLDER / MASTER_NAME
    if not master_path.exists():
        raise FileNotFoundError(f"Master not found: {master_path}")

    df_master = load_master(master_path)
    merged = df_master.copy()
    seen_postfix = set()
    skipped_dupe = []

    for path in sorted(FOLDER.glob(KEYWORDS_PREFIX + "*.csv")):
        postfix = postfix_from_filename(path.name)
        if postfix is None:
            continue
        if postfix in seen_postfix:
            skipped_dupe.append((path.name, postfix))
            continue

        df_eval = normalize_columns(read_csv_any(path))
        if JOIN_KEY not in df_eval.columns:
            print(f"[WARN] {path.name} skipped: missing join key '{JOIN_KEY}'.")
            continue
        df_eval[JOIN_KEY] = df_eval[JOIN_KEY].astype(str)

        slim, have_cols = select_and_suffix_eval(df_eval, postfix)
        if not have_cols:
            print(f"[WARN] {path.name} skipped: no SDG vote/comment columns found.")
            continue

        merged = merged.merge(slim, how="left", on=JOIN_KEY)
        seen_postfix.add(postfix)

    column_order = build_column_order(list(df_master.columns), list(merged.columns))
    merged = merged.reindex(columns=column_order)

    output_path = FOLDER / OUTPUT_NAME
    merged.to_csv(output_path, index=False, encoding="utf-8")

    print("=== Merge Report ===")
    print(f"Rows in master: {len(df_master)}")
    print(f"Processed files: {len(seen_postfix)}")
    if skipped_dupe:
        print("Skipped due to duplicate postfix:")
        print("\n".join(f"- {fname} (postfix '{pf}')" for fname, pf in skipped_dupe))
    print(f"Output: {output_path}")
    print(f"Total columns: {merged.shape[1]}")

if __name__ == "__main__":
    pd.set_option("display.width", 160)
    pd.set_option("display.max_columns", 200)
    main()



=== Merge Report ===
Rows in master: 307
Processed files: 2
Output: eval_results/bc4sdgs_202507_merged.csv
Total columns: 190


## Evolved USECASE BC4SDGS EVALUATOR with LIME enabled

Local Interpretable Model-agnostic Explanations.  
Explains a single prediction by fitting a simple, sparse surrogate around that instance.

**Reference**  
Ribeiro, M. T., Singh, S., Guestrin, C. (2016). “Why Should I Trust You?”: Explaining the predictions of any classifier. *KDD ’16*, 1135–1144. [doi:10.1145/2939672.2939778](https://doi.org/10.1145/2939672.2939778) • [PDF](https://dl.acm.org/doi/pdf/10.1145/2939672.2939778)

---

### How it is applied here

- The LLM classifier is wrapped as a `predict_proba` for LIME.  
- LIME perturbs the text and fits a local linear surrogate in bag-of-words space.  
- HTML shows **class k vs NOT k** for the top predicted score.  
- Explanations are saved per row and per SDG.

**Files written**
- `./lime/row{ID}_sdg{N}.html` — visual explanation  
- `./lime/row{ID}_sdg{N}.json` — `{row_id, sdg, probs[1..10], features}`

---

### Configuration (top of script)

- `ENABLE_LIME`: on/off.  
- `LIME_OUTPUT_DIR`: output folder (default `./lime`).  
- `LIME_NUM_FEATURES`: number of tokens shown.  
- `LIME_NUM_SAMPLES`: perturbed samples per explanation.  
- `LIME_SEED`: reproducibility.  
- `LIME_ROWS_LIMIT`: max rows to explain (does not limit scoring).  
- `LIME_SDG_LIST`: SDGs to explain, e.g. `(1, 8, 13)`.  
- `LIME_SMOOTH_EPS`: smoothing for class probabilities in the adapter.  
- `LIME_PERTURB_BATCH`, `TOKENS_PER_ITEM`, `MAX_TOKENS_CAP`: batching and output budget guards.

---

### Integration notes

- Full script preserves your existing LLM scoring.  
- Explanations run in addition to scoring; scoring still processes all pending rows.  
- Heavy caching and limits reduce API load.  
- `call_llm` supports temperature override; core scoring uses `TEMPERATURE`.  
- Environment token: `GPT_JRC_TOKEN`.  
- Fixed `OVERWRITE_SDG = TRUE` → `True`.

---

### Practical guidance

- Increase `LIME_NUM_SAMPLES` for stability.  
- Adjust `LIME_NUM_FEATURES` for more/less detail.  
- Use `LIME_SDG_LIST = tuple(range(1, SDGS_TOTAL+1))` to explain all SDGs.  
- `LIME_ROWS_LIMIT` caps explanations only; raise to inspect more rows.




# LIME: basics

- **Black box.** Treat the model as $f(x)$ that outputs class scores.  
- **Perturb.** Create many variants of the single input by randomly removing words.  
- **Query.** Run $f(z_i)$ on each perturbed text.  
- **Locality weights.** Weight each $z_i$ by proximity to the original: $\pi_x(z_i)=\exp(-d(x,z_i)^2/\sigma^2)$.  
- **Fit surrogate.** For class $k$, fit a sparse linear model $g$ on the weighted samples:  
  $$
  \arg\min_g \sum_i \pi_x(z_i)\,\big(f_k(z_i)-g(z_i)\big)^2 + \lambda\lVert g\rVert_1
  $$
- **Explain.** Coefficients of $g$ are token “weights”:  
  - positive → pushes prediction **toward $k$**  
  - negative → pushes **away from $k$** (“NOT $k$”)  
  - these are not probabilities.

---

## In this pipeline

- **Model $f$:** the LLM scorer wrapped as a 10-class classifier for a single SDG.  
- **`probs` (in JSON):** a **smoothed one-hot** from the predicted integer score:  
  - about `0.9` on the chosen score,  
  - about `0.1/9` on all others,  
  - controlled by `LIME_SMOOTH_EPS`.  
  These are surrogate probabilities, not calibrated confidence.  
- **`features`:** top `[token, weight]` pairs from the local linear surrogate.  
- **HTML “k vs NOT k”:** LIME explains one class at a time; the opposite side is **NOT k**.

---

## Not keyword matching

LIME is not simple text matching. It measures how the model’s output changes when words are removed, then fits a **local linear approximation** around the specific input.

---

## Tuning knobs

- **Stability:** increase `LIME_NUM_SAMPLES`.  
- **Detail:** adjust `LIME_NUM_FEATURES`.  
- **Text prep:** lowercase/clean; optionally add a custom tokenizer to include bigrams.  
- **Domain signal:** enrich inputs with SDG-specific terms, or switch to a Monte-Carlo adapter to estimate less “one-hot” class probabilities.


In [23]:
"""
SDG scoring script — PSTW-safe + robust incremental + clear startup summary
===========================================================================

- Writes only SDG_n_vote / SDG_n_comment (+ optional sdg_overall).
- Timestamped prints.
- TRUE incremental when OVERWRITE_SDG=False:
  * Selects any row with any missing/invalid SDG vote/comment.
  * Pages by ORDER BY id with a moving cursor (no OFFSET).
  * Counting uses predicate only; paging uses id cursor.
- LIME optional; LIME_ROWS_LIMIT limits explanations only, not scoring.
"""

from __future__ import annotations

import os
import re
import json
import math
import time
import logging
import sqlite3
from typing import List, Optional, Tuple, Sequence
from collections import OrderedDict
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
from openai import OpenAI, RateLimitError

# ----- Optional LIME -----
try:
    from lime.lime_text import LimeTextExplainer
    LIME_AVAILABLE = True
except Exception:
    LIME_AVAILABLE = False

# --------------------------------------------------------------------------- #
# Configuration
# --------------------------------------------------------------------------- #
RELMMYY = "202507"
DATA_DIR = "./data"
SDGS_PATH = os.path.join(DATA_DIR, "SDGs.csv")
DB_PATH = os.path.join(DATA_DIR, "pstw_local_database_" + RELMMYY + ".db")

DEBUG = False
OVERWRITE_SDG = True           # True = recompute all rows; False = only incomplete rows
BATCH_SIZE = 50
NUM_ROWS_TO_PROCESS = 2000
TEMPERATURE = 0.3
MODEL_NAME = "llama-3.3-70b-instruct"
SDGS_TOTAL = 17
SDGS_BATCH = 6
LOG_FILE = "debug_logfile.log"

# Printing
PRINT_FLUSH = True


def _ts() -> str:
    """Current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return datetime.now().isoformat(sep=" ", timespec="seconds")


def p(msg: str) -> None:
    """Print ``msg`` prefixed with a timestamp (flushed when PRINT_FLUSH is set)."""
    stamp = _ts()
    print(f"{stamp} | {msg}", flush=PRINT_FLUSH)


# LIME controls
ENABLE_LIME = True
LIME_OUTPUT_DIR = os.path.abspath(os.getenv("LIME_OUTPUT_DIR", "./lime"))
LIME_NUM_FEATURES = 10
LIME_NUM_SAMPLES = 250
LIME_SEED = 0
LIME_ROWS_LIMIT = 10            # limits explanations only
LIME_SDG_LIST: Sequence[int] = tuple(range(1, 18))   # SDGs 1..17
LIME_SMOOTH_EPS = 0.10
LIME_PERTURB_BATCH = 8
TOKENS_PER_ITEM = 42
MAX_TOKENS_CAP = 2800

# LLM throttle: minimum seconds between consecutive requests
MIN_INTERVAL = 0.35
_last_call_ts = 0.0

# --------------------------------------------------------------------------- #
# Logging
# --------------------------------------------------------------------------- #
# NOTE: logging.basicConfig is a no-op when the root logger already has handlers
# (e.g. configured by an earlier cell in the same kernel).
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.DEBUG if DEBUG else logging.INFO,
    filename=LOG_FILE,
)
logger = logging.getLogger(__name__)

# --------------------------------------------------------------------------- #
# OpenAI client
# --------------------------------------------------------------------------- #
api_key = os.environ.get("GPT_JRC_TOKEN")
if api_key:
    client = OpenAI(api_key=api_key, base_url="https://api-gpt.jrc.ec.europa.eu/v1")
else:
    raise EnvironmentError("Set the GPT_JRC_TOKEN environment variable.")

def _throttle():
    """Sleep as needed so consecutive LLM calls start at least MIN_INTERVAL seconds apart."""
    global _last_call_ts
    earliest_next = _last_call_ts + MIN_INTERVAL
    remaining = earliest_next - time.perf_counter()
    if remaining > 0:
        time.sleep(remaining)
    _last_call_ts = time.perf_counter()

# --------------------------------------------------------------------------- #
# Helpers
# --------------------------------------------------------------------------- #
def load_sdgs(path: str, batch_size: int) -> pd.DataFrame:
    """Read the SDG reference CSV and keep only its first ``batch_size`` rows."""
    sdg_frame = pd.read_csv(path)
    return sdg_frame.head(batch_size)

def ensure_sdg_columns(cur: sqlite3.Cursor, total: Optional[int] = None) -> None:
    """Add any missing SDG_<i>_vote / SDG_<i>_comment columns, plus sdg_overall, to pstw_data.

    Existing columns are detected with ``PRAGMA table_info`` rather than by
    catching ALTER TABLE failures. Genuine errors (locked database, etc.) are
    therefore no longer swallowed, as the old ``sdg_overall`` branch did. The
    function is idempotent.

    Args:
        cur: cursor on the PSTW SQLite database.
        total: number of SDGs; defaults to the module-level SDGS_TOTAL.

    Raises:
        sqlite3.OperationalError: pstw_data does not exist or cannot be altered.
    """
    n_sdgs = SDGS_TOTAL if total is None else total
    # SQLite column names are case-insensitive, so compare lower-cased.
    existing = {row[1].lower() for row in cur.execute("PRAGMA table_info(pstw_data)")}
    if not existing:
        # PRAGMA table_info returns no rows for a missing table.
        raise sqlite3.OperationalError("no such table: pstw_data")

    wanted = [
        (f"SDG_{i}_{suffix}", col_type)
        for i in range(1, n_sdgs + 1)
        for suffix, col_type in (("vote", "INTEGER"), ("comment", "TEXT"))
    ]
    wanted.append(("sdg_overall", "INTEGER"))

    for col, col_type in wanted:
        if col.lower() not in existing:
            cur.execute(f"ALTER TABLE pstw_data ADD COLUMN {col} {col_type};")
            existing.add(col.lower())

def _ensure_indexes(cur: sqlite3.Cursor) -> None:
    """Create, if absent, an index on pstw_data(id), the column used for id-ordered paging.

    NOTE(review): redundant if ``id`` is declared INTEGER PRIMARY KEY (a rowid
    alias is already indexed) — confirm against the table schema.
    """
    cur.execute("CREATE INDEX IF NOT EXISTS idx_pstw_id ON pstw_data(id)")

def parse_scores(raw: str, expected: Optional[int] = None) -> Optional[List[int]]:
    """Parse an LLM reply into exactly ``expected`` integer scores in 1..10.

    Tokens are whitespace-separated. Commas, semicolons and square brackets
    around a token ("5," or "[5") are stripped, so list-style replies still parse.
    ``str.isdecimal`` is used instead of ``str.isdigit``: ``isdigit`` accepts
    characters such as "²" that ``int()`` rejects with ValueError.

    Args:
        raw: model output text (None/empty is treated as no scores).
        expected: required number of scores; defaults to SDGS_TOTAL.

    Returns:
        The list of scores, or None if the count is wrong or any value is outside 1..10.
    """
    n_expected = SDGS_TOTAL if expected is None else expected
    nums: List[int] = []
    for tok in (raw or "").split():
        tok = tok.strip(",;[]")
        if tok.isdecimal():
            nums.append(int(tok))
    return nums if len(nums) == n_expected and all(1 <= n <= 10 for n in nums) else None

def call_llm(messages, max_tokens: int, temperature: Optional[float] = None) -> str:
    """Make a throttled chat-completion call with retries; return the stripped reply or "".

    Makes up to 4 attempts, with exponential back-off (1s, 2s, 4s) between them.
    There is no sleep after the final failure. Every request first goes through
    _throttle(), so consecutive calls start at least MIN_INTERVAL seconds apart.

    Args:
        messages: OpenAI-style chat messages.
        max_tokens: completion token budget.
        temperature: sampling temperature; None means the global TEMPERATURE.
    """
    attempts = 4
    delay = 1.0
    for attempt in range(1, attempts + 1):
        try:
            _throttle()
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=messages,
                max_tokens=max_tokens,
                temperature=(TEMPERATURE if temperature is None else temperature),
            )
            return (resp.choices[0].message.content or "").strip()
        except Exception as exc:  # RateLimitError is an Exception subclass
            if attempt == attempts:
                logger.error("LLM error (attempt %d): %s – giving up", attempt, exc)
                break
            logger.warning("LLM error (attempt %d): %s – retrying in %.1fs", attempt, exc, delay)
            time.sleep(delay)
            delay *= 2
    return ""

def evaluate_case(description: str, system_prompt: str) -> Optional[List[int]]:
    """Score one initiative on all 17 SDGs at temperature 0.

    Returns None for an empty description or an unparsable reply.
    """
    if not description:
        return None
    header = "Return only 17 integers 1..10 separated by spaces. No text.\n\n"
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": header + f"{description}\n\n17 numbers:"},
    ]
    reply = call_llm(messages, max_tokens=60, temperature=0.0)
    return parse_scores(reply)

def generate_sdg_comments(description: str, scores: List[int]) -> List[str]:
    """Ask the model for a short justification of each SDG score.

    Args:
        description: initiative text that was scored.
        scores: the previously assigned scores, SDG 1 first.

    Returns:
        SDGS_TOTAL comment strings, each at most 500 characters. A comment is ""
        when the reply contained no parsable line for that SDG.
    """
    score_vec = " ".join(map(str, scores))
    prompt = (
        "You are a senior sustainability analyst. For the initiative below you assigned a score for each SDG.\n"
        "For each SDG 1..17 write ≤500 chars explaining the numeric score. Avoid generic SDG definitions.\n"
        "Output exactly 17 lines as: N: (score) comment\n\n"
        f"Initiative:\n{description}\n\nScores: {score_vec}"
    )
    raw = call_llm([{"role": "user", "content": prompt}], max_tokens=1500, temperature=TEMPERATURE)
    comments = ["" for _ in range(SDGS_TOTAL)]
    # Accept the requested "N: ..." form and common variants such as "SDG N: ...",
    # "**N**: ..." or "- N: ...". Unlike str.isdigit(), \d does not match
    # superscripts such as "¹", which int() would reject.
    line_re = re.compile(r"^[\s*#>\-]*(?:SDG\s*)?(\d+)\**\s*:\s*(.*)$", re.IGNORECASE)
    for line in raw.splitlines():
        m = line_re.match(line)
        if not m:
            continue
        i = int(m.group(1))
        if 1 <= i <= SDGS_TOTAL:
            comments[i - 1] = m.group(2).strip()[:500]
    return comments

# --------------------------------------------------------------------------- #
# Batched scoring for LIME perturbations
# --------------------------------------------------------------------------- #
# Bounded text -> 17-score cache. Insertion/refresh order is tracked, and the
# oldest entry is evicted once the size exceeds SCORE_CACHE_MAX.
SCORE_CACHE: "OrderedDict[str, List[int]]" = OrderedDict()
SCORE_CACHE_MAX = 20000

def _cache_put(text: str, scores: Optional[List[int]]):
    """Insert or refresh ``text`` in SCORE_CACHE; failed scorings (None) are not cached."""
    if scores is not None:
        SCORE_CACHE[text] = scores
        SCORE_CACHE.move_to_end(text)  # mark as most recently written
        if len(SCORE_CACHE) > SCORE_CACHE_MAX:
            SCORE_CACHE.popitem(last=False)  # drop the oldest entry

# One reply line per item: "<index>: n1 ... n17". "<index>." and "<index>)" are
# also accepted, because the prompt enumerates items as "1. text" and models
# often mirror that style in their replies.
_idx_line = re.compile(r"^\s*(\d+)\s*[:.)]\s*(.*)$")
_int_tok = re.compile(r"\b(\d{1,2})\b")

def _parse_batch_output(chunk_len: int, txt: str) -> List[Optional[List[int]]]:
    """Map a batched scoring reply to one 17-score list (or None) per item.

    Args:
        chunk_len: number of items in the prompt; indices are 1-based.
        txt: raw model reply.

    Returns:
        A list of length ``chunk_len``. Entry i holds the scores from the line
        indexed i+1 when that line has exactly SDGS_TOTAL integers in 1..10;
        otherwise it is None. If an index appears more than once, the last valid
        line wins.
    """
    out: List[Optional[List[int]]] = [None] * chunk_len
    if not txt:
        return out
    for line in txt.splitlines():
        m = _idx_line.match(line)
        if not m:
            continue
        idx = int(m.group(1)) - 1
        if 0 <= idx < chunk_len:
            nums = [int(t) for t in _int_tok.findall(m.group(2))]
            if len(nums) == SDGS_TOTAL and all(1 <= n <= 10 for n in nums):
                out[idx] = nums
    return out

def evaluate_cases_batch(texts: List[str], system_prompt: str, batch_size: int = LIME_PERTURB_BATCH) -> List[Optional[List[int]]]:
    """Score many texts with few LLM calls by packing ``batch_size`` items into each prompt.

    Each item is whitespace-collapsed and truncated to 800 characters before it is
    sent. The result has one entry per input text, in input order: the 17 scores,
    or None where the reply line for that item was missing or malformed.
    """
    results: List[Optional[List[int]]] = []
    for offset in range(0, len(texts), batch_size):
        chunk = texts[offset:offset + batch_size]
        numbered = "\n".join(
            f"{pos}. {' '.join(item.split())[:800]}" for pos, item in enumerate(chunk, start=1)
        )
        prompt_user = (
            "For each item, output one line with exactly 17 integers 1..10 separated by spaces.\n"
            "Format strictly: <index>: n1 n2 ... n17\n"
            "No explanations. No extra lines.\n\n" + numbered + "\n\nOutput:"
        )
        reply = call_llm(
            [{"role": "system", "content": system_prompt},
             {"role": "user", "content": prompt_user}],
            max_tokens=min(MAX_TOKENS_CAP, TOKENS_PER_ITEM * len(chunk) + 120),
            temperature=0.0,
        )
        # _parse_batch_output always returns exactly len(chunk) entries.
        results.extend(_parse_batch_output(len(chunk), reply))
    return results

class BatchedSDGAdapter:
    """LIME classifier adapter: P(score ∈ {1..10} | text) for one SDG.

    17-score vectors are read from SCORE_CACHE. Uncached texts are scored in
    batches via evaluate_cases_batch. The integer score s for SDG ``sdg_idx``
    becomes a smoothed one-hot row: 1 - smooth_eps on class s-1 and
    smooth_eps / 9 on every other class. Texts that could not be scored get a
    uniform row.
    """

    def __init__(self, sdg_idx: int, system_prompt: str, smooth_eps: float = 0.10):
        self.sdg_idx = sdg_idx            # 1-based SDG number
        self.system_prompt = system_prompt
        self.smooth_eps = smooth_eps

    def predict_proba(self, texts: List[str]):
        """Return a (len(texts), 10) float array of class probabilities."""
        # LIME perturbations frequently repeat, so deduplicate (order-preserving)
        # before querying; identical prompts would otherwise be paid for repeatedly.
        missing = list(dict.fromkeys(t for t in texts if t not in SCORE_CACHE))
        if missing:
            p(f"[LIME] filling cache for {len(missing)} texts")
            results = evaluate_cases_batch(missing, self.system_prompt, batch_size=LIME_PERTURB_BATCH)
            for t, scores in zip(missing, results):
                _cache_put(t, scores)

        K = 10
        rows = []
        for t in texts:
            scores = SCORE_CACHE.get(t)
            if not scores:
                rows.append([1.0 / K] * K)
                continue
            s = int(max(1, min(10, scores[self.sdg_idx - 1])))
            vec = [self.smooth_eps / (K - 1)] * K
            vec[s - 1] = 1.0 - self.smooth_eps
            rows.append(vec)
        arr = np.array(rows, dtype=float)
        # Defensive renormalisation in case smoothing produced rows not summing to 1.
        if not np.allclose(arr.sum(axis=1), 1.0, atol=1e-6):
            for i in range(arr.shape[0]):
                s = arr[i, :].sum()
                if s > 0:
                    arr[i, :] /= s
        return arr

# --------------------------------------------------------------------------- #
# LIME wrappers
# --------------------------------------------------------------------------- #
def _fallback_html(row_id: int, sdg_idx: int, err: str) -> str:
    """Build a minimal HTML page reporting that a LIME explanation could not be produced.

    Args:
        row_id: Database id of the row being explained.
        sdg_idx: 1-based SDG index.
        err: Error text (typically an exception ``repr``); HTML-escaped before embedding
            because reprs routinely contain ``<``, ``>`` and ``&``.

    Returns:
        A self-contained HTML document string.
    """
    from html import escape

    return (
        "<html><body>"
        "<h3>LIME explanation unavailable</h3>"
        f"<p>Row {row_id} SDG {sdg_idx}</p>"
        f"<pre style='white-space:pre-wrap'>{escape(str(err))}</pre>"
        "</body></html>"
    )

def lime_explain_single_sdg(
    text: str,
    row_id: int,
    sdg_idx: int,
    sdg_title: str,
    sdg_desc: str,
    system_prompt: str,
    num_features: int,
    num_samples: int,
    seed: int,
):
    """Explain, with LIME, the predicted 1..10 score of one text for one SDG.

    The unperturbed text is scored first so the explanation targets the predicted
    class. ``LimeTextExplainer.explain_instance`` otherwise defaults to
    ``labels=(1,)``, i.e. it would always explain class "2".

    Args:
        text: Case text to explain.
        row_id: Database id (used for logging and the fallback page only).
        sdg_idx: 1-based SDG index.
        sdg_title: Accepted for interface compatibility; not used here.
        sdg_desc: Accepted for interface compatibility; not used here.
        system_prompt: Prompt forwarded to the scoring adapter.
        num_features: Max number of tokens in the explanation.
        num_samples: Number of LIME perturbations.
        seed: ``random_state`` for LIME's sampler.

    Returns:
        ``(html, top_idx, probs, feats)``:
        - ``html``: LIME HTML (or a fallback page);
        - ``top_idx``: 0-based index of the most probable score (score = ``top_idx + 1``);
        - ``probs``: 10 class probabilities;
        - ``feats``: ``(token, weight)`` pairs for that class.
        On any exception it returns a fallback page, ``0``, uniform probabilities and ``[]``.
    """
    adapter = BatchedSDGAdapter(sdg_idx=sdg_idx, system_prompt=system_prompt, smooth_eps=LIME_SMOOTH_EPS)
    explainer = LimeTextExplainer(class_names=[str(i) for i in range(1, 11)], random_state=seed)
    p(f"[LIME] Start row {row_id} SDG {sdg_idx} | samples={num_samples} (batched, cached)")
    t0 = time.perf_counter()
    try:
        # Scoring first also fills the cache, so LIME's unperturbed sample is a cache hit.
        probs = adapter.predict_proba([text])[0]
        top_idx = int(np.argmax(probs))
        exp = explainer.explain_instance(
            text_instance=text,
            classifier_fn=adapter.predict_proba,
            labels=(top_idx,),
            num_features=num_features,
            num_samples=num_samples,
        )
        # predict_proba=False skips LIME's probability bar chart in the HTML.
        html_doc = exp.as_html(labels=(top_idx,), predict_proba=False)
        if not html_doc or not html_doc.strip():
            html_doc = _fallback_html(row_id, sdg_idx, "Empty HTML from LIME.")
        feats = exp.as_list(label=top_idx)
        p(f"[LIME] Done  row {row_id} SDG {sdg_idx} in {time.perf_counter()-t0:.1f}s "
          f"| feats={len(feats)} | html_len={len(html_doc)}")
        return html_doc, top_idx, probs.tolist(), feats
    except Exception as e:
        err = repr(e)
        p(f"[LIME] exception row {row_id} SDG {sdg_idx}: {err}")
        html_doc = _fallback_html(row_id, sdg_idx, err)
        probs = [1.0 / 10] * 10
        return html_doc, 0, probs, []

def save_text(path: str, content: str):
    """Write ``content`` to ``path`` as UTF-8, creating parent directories if needed.

    A bare filename (no directory part) is written to the current working directory;
    the guard avoids ``os.makedirs("")``, which raises ``FileNotFoundError``.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.write(content)

# --------------------------------------------------------------------------- #
# Incomplete-row predicate + diagnostics
# --------------------------------------------------------------------------- #
def _build_incomplete_predicate() -> str:
    """Return a SQL boolean expression that is true for rows still needing scoring.

    A row is incomplete when any ``SDG_i_vote`` is NULL or outside 1..10, any
    ``SDG_i_comment`` is NULL or blank, or ``sdg_overall`` is NULL or outside 1..10
    (``i`` runs over ``1..SDGS_TOTAL``). Vote clauses come first, then comment
    clauses, then the overall clause, all OR-ed inside one pair of parentheses.
    """
    sdg_numbers = range(1, SDGS_TOTAL + 1)
    clauses = [f"(SDG_{n}_vote IS NULL OR SDG_{n}_vote NOT BETWEEN 1 AND 10)" for n in sdg_numbers]
    clauses.extend(f"(SDG_{n}_comment IS NULL OR TRIM(SDG_{n}_comment)='')" for n in sdg_numbers)
    clauses.append("(sdg_overall IS NULL OR sdg_overall NOT BETWEEN 1 AND 10)")
    return "({})".format(" OR ".join(clauses))

def _startup_summary(cur: sqlite3.Cursor, base_cond: str) -> None:
    """Log an overview of ``pstw_data`` before processing starts.

    Reports the database path, the total row count, done vs. pending counts
    (pending = rows matching ``base_cond``), the id range, and up to ten lowest
    pending ids. When nothing is pending and ``OVERWRITE_SDG`` is off, it also
    prints a hint on how to force a full recompute.
    """
    def scalar(sql: str):
        return cur.execute(sql).fetchone()[0]

    total_rows = scalar("SELECT COUNT(*) FROM pstw_data")
    pending_rows = scalar(f"SELECT COUNT(*) FROM pstw_data WHERE {base_cond}")
    first_pending = [
        rec[0]
        for rec in cur.execute(
            f"SELECT id FROM pstw_data WHERE {base_cond} ORDER BY id LIMIT 10"
        ).fetchall()
    ]
    lowest_id = scalar("SELECT MIN(id) FROM pstw_data")
    highest_id = scalar("SELECT MAX(id) FROM pstw_data")

    p(f"Database: {DB_PATH}")
    p(f"Rows total: {total_rows} | already done: {total_rows - pending_rows} | pending: {pending_rows}")
    shown_ids = first_pending if first_pending else '[]'
    p(f"ID range: [{lowest_id}, {highest_id}] | first pending ids: {shown_ids}")
    if pending_rows == 0 and not OVERWRITE_SDG:
        p("Note: OVERWRITE_SDG=False and predicate finds no pending rows. "
          "Set OVERWRITE_SDG=True to recompute all rows.")

# --------------------------------------------------------------------------- #
# Main processing
# --------------------------------------------------------------------------- #
def _fmt_td(seconds: float) -> str:
    """Render a duration as ``[D day(s), ]H:MM:SS`` (fractional seconds truncated).

    Non-finite input (``inf`` or ``nan``) yields the string ``"inf"``.
    """
    if math.isfinite(seconds):
        whole_seconds = int(seconds)
        return str(timedelta(seconds=whole_seconds))
    return "inf"

def _score_update_sql() -> str:
    """Return the parameterised UPDATE writing all SDG votes, comments and ``sdg_overall``.

    Placeholder order: SDG_1..N votes, SDG_1..N comments, overall, row id.
    """
    sdg_numbers = range(1, SDGS_TOTAL + 1)
    vote_set = ", ".join(f"SDG_{i}_vote = ?" for i in sdg_numbers)
    cmt_set = ", ".join(f"SDG_{i}_comment = ?" for i in sdg_numbers)
    return f"UPDATE pstw_data SET {vote_set}, {cmt_set}, sdg_overall = ? WHERE id = ?"


def _run_lime_for_row(text: str, row_id: int, system_prompt: str,
                      sdg_meta: List[Tuple[int, str, str]]) -> None:
    """Explain every SDG in ``LIME_SDG_LIST`` for one row and write ``.html`` + ``.json`` files."""
    for sdg_idx in LIME_SDG_LIST:
        meta = sdg_meta[sdg_idx - 1]
        html_doc, _lbl_idx, probs, feats = lime_explain_single_sdg(
            text=text,
            row_id=row_id,
            sdg_idx=sdg_idx,
            sdg_title=meta[1],
            sdg_desc=meta[2],
            system_prompt=system_prompt,
            num_features=LIME_NUM_FEATURES,
            num_samples=LIME_NUM_SAMPLES,
            seed=LIME_SEED,
        )
        out_base = os.path.join(LIME_OUTPUT_DIR, f"row{row_id}_sdg{sdg_idx}")
        save_text(out_base + ".html", html_doc)
        with open(out_base + ".json", "w", encoding="utf-8") as jf:
            json.dump({"row_id": row_id, "sdg": sdg_idx, "probs": probs, "features": feats}, jf)


def _log_row_done(row_id: int, processed: int, total: int, row_t0: float, start_ts: float) -> None:
    """Print the per-row completion line with average seconds/row and ETA for remaining rows."""
    row_elapsed = time.perf_counter() - row_t0
    overall_elapsed = time.perf_counter() - start_ts
    rate = processed / overall_elapsed if overall_elapsed > 0 else 0.0
    per_row = (1 / rate) if rate > 0 else float("inf")
    eta = (total - processed) * per_row if math.isfinite(per_row) else float("inf")
    p(f"[{processed}/{total}] id={row_id}: done in {row_elapsed:.1f}s | avg {per_row:.1f}s/row | ETA {_fmt_td(eta)}")


def process_cases(system_prompt: str, sdg_meta: List[Tuple[int, str, str]]) -> None:
    """Score pending ``pstw_data`` rows on all SDGs and optionally write LIME explanations.

    Rows match ``_build_incomplete_predicate()``, or every row when ``OVERWRITE_SDG``
    is set. They are paged by ascending ``id`` in chunks of ``BATCH_SIZE`` and capped
    at ``NUM_ROWS_TO_PROCESS``. For each row:
    - the description (falling back to the name) is scored with ``evaluate_case``;
    - it is commented with ``generate_sdg_comments``;
    - votes, comments and the rounded mean (``sdg_overall``) are written back.
    For the first ``LIME_ROWS_LIMIT`` scored rows, LIME explanations for each SDG in
    ``LIME_SDG_LIST`` are saved under ``LIME_OUTPUT_DIR``.

    Args:
        system_prompt: System prompt passed to the scorer and to LIME.
        sdg_meta: ``(sdg_number, title, keywords)`` per SDG, indexed by ``sdg_number - 1``.
    """
    from contextlib import closing

    start_ts = time.perf_counter()
    lime_enabled = ENABLE_LIME and LIME_AVAILABLE
    update_sql = _score_update_sql()  # loop-invariant

    # sqlite3.Connection's own context manager only commits/rolls back and never closes;
    # closing() releases the handle while the inner `with conn` keeps commit/rollback.
    with closing(sqlite3.connect(DB_PATH)) as conn:
        with conn:
            cur = conn.cursor()
            ensure_sdg_columns(cur)
            _ensure_indexes(cur)

            base_cond = "1=1" if OVERWRITE_SDG else _build_incomplete_predicate()
            _startup_summary(cur, base_cond)

            # total remaining regardless of paging cursor
            pending = cur.execute(f"SELECT COUNT(*) FROM pstw_data WHERE {base_cond}").fetchone()[0]
            total = min(pending, NUM_ROWS_TO_PROCESS)
            p(f"Rows to process (SDG scoring): {total} (predicate-only)")
            logger.info("Rows to process (SDG scoring): %d", total)

            processed = 0
            lime_rows_done = 0
            last_id = -1  # keyset paging cursor

            if lime_enabled:
                os.makedirs(LIME_OUTPUT_DIR, exist_ok=True)

            while processed < total:
                limit = min(BATCH_SIZE, total - processed)
                rows = cur.execute(
                    f"""
                    SELECT id, name, description
                    FROM pstw_data
                    WHERE {base_cond} AND id > ?
                    ORDER BY id
                    LIMIT ?
                    """,
                    (last_id, limit),
                ).fetchall()
                if not rows:
                    break

                batch_t0 = time.perf_counter()

                for row_id, name, desc in rows:
                    # Advance the cursor even for skipped rows so they are not re-fetched this run.
                    last_id = row_id
                    row_t0 = time.perf_counter()
                    text = (desc or "").strip() or (name or "").strip()
                    if not text:
                        p(f"[{processed+1}/{total}] id={row_id}: empty text, skipped.")
                        processed += 1
                        continue

                    p(f"[{processed+1}/{total}] id={row_id}: scoring…")
                    scores = evaluate_case(text, system_prompt)
                    # A wrong-length vector would break the UPDATE bindings and abort the run.
                    if not scores or len(scores) != SDGS_TOTAL:
                        p(f"[{processed+1}/{total}] id={row_id}: parsing failed.")
                        processed += 1
                        continue

                    comments = generate_sdg_comments(text, scores)
                    overall = round(sum(scores) / SDGS_TOTAL)
                    cur.execute(update_sql, [*scores, *comments, overall, row_id])
                    # Persist scores before the slow LIME step so an interrupt or crash
                    # there does not roll back rows that were already scored.
                    conn.commit()

                    # LIME (explanations only)
                    if lime_enabled and lime_rows_done < LIME_ROWS_LIMIT:
                        _run_lime_for_row(text, row_id, system_prompt, sdg_meta)
                        lime_rows_done += 1

                    processed += 1
                    _log_row_done(row_id, processed, total, row_t0, start_ts)

                conn.commit()
                batch_elapsed = time.perf_counter() - batch_t0
                p(f"[batch] committed {len(rows)} rows in {batch_elapsed:.1f}s | total {processed}/{total}")
                logger.info("Processed %d/%d", processed, total)

            p("SDG scoring complete.")
            logger.info("SDG scoring complete.")

# --------------------------------------------------------------------------- #
# Entry
# --------------------------------------------------------------------------- #
def main() -> None:
    """Entry point: report LIME status, build the SDG prompt and metadata, then score all cases.

    The system prompt embeds every ``Keywords Merged`` entry from the SDG table.
    The per-SDG metadata passed to ``process_cases`` is exactly ``SDGS_TOTAL``
    ``(number, title, keywords)`` tuples. Missing titles default to ``"SDG n"``
    and missing keywords to ``""``.
    """
    p("=== SDG scoring started ===")
    p(f"cwd={os.getcwd()}")
    p(f"LIME status | enabled={ENABLE_LIME} available={LIME_AVAILABLE} outdir={LIME_OUTPUT_DIR} "
      f"rows_limit={LIME_ROWS_LIMIT} sdgs={list(LIME_SDG_LIST)}")
    if ENABLE_LIME:
        if LIME_AVAILABLE:
            os.makedirs(LIME_OUTPUT_DIR, exist_ok=True)
            p(f"LIME output dir ready: {LIME_OUTPUT_DIR}")
        else:
            p("LIME disabled: package not installed")

    sdg_df = load_sdgs(SDGS_PATH, SDGS_BATCH)
    keyword_sets = sdg_df["Keywords Merged"].astype(str).tolist()
    sys_prompt = (
        "You are a senior sustainability analyst. Score each initiative from 1 to 10 for all 17 SDGs. "
        "Use these SDG keyword sets as reference:\n" + "\n".join(keyword_sets)
    )

    if "Title" in sdg_df.columns:
        titles = sdg_df["Title"].tolist()
    else:
        titles = [f"SDG {i}" for i in range(1, SDGS_TOTAL + 1)]
    # Pad with generic names / empty keywords, then trim, so both lists have SDGS_TOTAL entries.
    title_padding = [f"SDG {i}" for i in range(len(titles) + 1, SDGS_TOTAL + 1)]
    titles = (titles + title_padding)[:SDGS_TOTAL]
    descs = (keyword_sets + [""] * (SDGS_TOTAL - len(keyword_sets)))[:SDGS_TOTAL]
    sdg_meta: List[Tuple[int, str, str]] = [
        (number, title, desc) for number, (title, desc) in enumerate(zip(titles, descs), start=1)
    ]

    process_cases(sys_prompt, sdg_meta)
    p("=== Finished ===")

# Script entry point. Inside a Jupyter/IPython kernel __name__ is "__main__", so
# executing this cell immediately starts the full scoring run (see the log below).
if __name__ == "__main__":
    main()



2025-09-02 20:21:49 | === SDG scoring started ===
2025-09-02 20:21:49 | cwd=/delta/JRC/services/jupyter-home/combmar@delta.europa.eu/BC4SDGs
2025-09-02 20:21:49 | LIME status | enabled=True available=True outdir=/delta/JRC/services/jupyter-home/combmar@delta.europa.eu/BC4SDGs/lime rows_limit=10 sdgs=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17]
2025-09-02 20:21:49 | LIME output dir ready: /delta/JRC/services/jupyter-home/combmar@delta.europa.eu/BC4SDGs/lime
2025-09-02 20:21:49 | Database: ./data/pstw_local_database_202507.db
2025-09-02 20:21:49 | Rows total: 306 | already done: 0 | pending: 306
2025-09-02 20:21:49 | ID range: [1, 306] | first pending ids: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
2025-09-02 20:21:49 | Rows to process (SDG scoring): 306 (predicate-only)
2025-09-02 20:21:49,087 - INFO - Rows to process (SDG scoring): 306
2025-09-02 20:21:49,087 - INFO - Rows to process (SDG scoring): 306
2025-09-02 20:21:49 | [1/306] id=1: scoring…
2025-09-02 20:22:00 | [LIME] Start 