# Evaluate CVE→CWE Matching Accuracy (Hybrid RAG)

Goal: evaluate how well our system predicts CWE(s) **when the CVE record already contains explicit CWE labels**.

Method:
- Use all CVEs with at least one explicit `cweId` under `containers.cna.problemTypes[].descriptions[]`.
- Hide the label(s) and predict from the CVE description alone.
- Accuracy = percentage of CVEs for which a predicted CWE matches any ground-truth CWE of that CVE.

We report:
- **Top‑1 accuracy** (best prediction)
- **Top‑k accuracy** (ground-truth appears in retrieved top‑k list)

Notes:
- Running over *all* CVEs can take a while; the notebook supports sampling / max limits.
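
To make the two metrics concrete, here is a minimal sketch on hypothetical toy data (the helper name `topk_hits` and the three toy records are assumptions for illustration, not part of the notebook). A CVE counts as a hit when any ground-truth CWE appears at rank 1 (top‑1) or anywhere in the first k predictions (top‑k).

```python
def topk_hits(ranked_preds, true_cwes, k):
    """Return (top1_hit, topk_hit) for a single CVE."""
    true_set = set(true_cwes)
    top = ranked_preds[:k]
    top1_hit = bool(top) and top[0] in true_set
    topk_hit = any(p in true_set for p in top)
    return top1_hit, topk_hit


# Hypothetical toy data: (ranked predictions, ground-truth labels)
toy = [
    (["CWE-79", "CWE-80", "CWE-116"], ["CWE-79"]),            # hit at rank 1
    (["CWE-20", "CWE-89", "CWE-74"], ["CWE-89", "CWE-564"]),  # hit at rank 2 only
    (["CWE-22", "CWE-23", "CWE-36"], ["CWE-78"]),             # miss
]

hits = [topk_hits(preds, truth, k=3) for preds, truth in toy]
top1_acc = sum(h[0] for h in hits) / len(hits)
top3_acc = sum(h[1] for h in hits) / len(hits)
print(f"top1={top1_acc:.3f} top3={top3_acc:.3f}")
```

Because a CVE may carry several labels, matching against the set of ground-truth CWEs keeps multi-label records from being penalized for predicting only one of them.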

In [2]:
import json
import re
import math
import time
import random
import shutil
import subprocess
from pathlib import Path

import numpy as np
import scipy.sparse as sp
import xml.etree.ElementTree as ET

PROJECT_ROOT = Path.cwd()
CWE_XML_PATH = PROJECT_ROOT / "data" / "cwec_v4.19.xml"
CVE_ROOT = PROJECT_ROOT / "data" / "cvelistV5-main" / "cves"

print("Project root:", PROJECT_ROOT)
print("CWE XML:", CWE_XML_PATH)
print("CVE root:", CVE_ROOT)
print("Exists?", CWE_XML_PATH.exists(), CVE_ROOT.exists())

Project root: /home/dnfy/Desktop/Fortiss
CWE XML: /home/dnfy/Desktop/Fortiss/data/cwec_v4.19.xml
CVE root: /home/dnfy/Desktop/Fortiss/data/cvelistV5-main/cves
Exists? True True


## 1. Load CWE definitions (ground truth catalog)

We build:
- `cwe_map`: `CWE-<id>` → `{name, description}`
- `cwe_corpus`: list of dicts (`id`, `name`, `description`, `text`); the `text` field is the string the retriever indexes

In [3]:
def parse_cwe_database(xml_path: Path):
    cwe_map = {}
    cwe_corpus = []

    tree = ET.parse(xml_path)
    root = tree.getroot()

    ns = {"cwe": root.tag.split("}")[0].strip("{")} if "}" in root.tag else {}
    xpath = ".//cwe:Weakness" if ns else ".//Weakness"

    for weakness in root.findall(xpath, ns):
        wid = weakness.get("ID")
        wname = weakness.get("Name")

        desc_elem = weakness.find("cwe:Description", ns) if ns else weakness.find("Description")
        description = (desc_elem.text or "").strip() if desc_elem is not None else ""
        if not description:
            description = "No description available."

        cwe_id = f"CWE-{wid}"
        cwe_map[cwe_id] = {"name": wname, "description": description}

        text = f"{cwe_id}: {wname}. {description}"
        cwe_corpus.append({"id": cwe_id, "name": wname, "description": description, "text": text})

    return cwe_map, cwe_corpus


cwe_map, cwe_corpus = parse_cwe_database(CWE_XML_PATH)
print("CWEs:", len(cwe_map))

CWEs: 969
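
The namespace handling is the least obvious part of the parser, so the sketch below reproduces it on a tiny inline catalog. The XML content and its namespace URI are hypothetical and only imitate the shape of a namespaced catalog; `toy_map` is an illustrative name.

```python
import xml.etree.ElementTree as ET

# Hypothetical two-entry catalog; the namespace URI is made up for illustration.
SAMPLE_XML = """<Weakness_Catalog xmlns="http://example.org/cwe-toy">
  <Weaknesses>
    <Weakness ID="79" Name="Toy XSS entry">
      <Description>Input is not neutralized before being placed in a web page.</Description>
    </Weakness>
    <Weakness ID="9999" Name="Toy entry without description"/>
  </Weaknesses>
</Weakness_Catalog>"""

toy_root = ET.fromstring(SAMPLE_XML)

# With a default namespace, root.tag looks like "{namespace-uri}Weakness_Catalog",
# so the URI can be recovered from the tag and bound to a prefix for XPath lookups.
ns = {"cwe": toy_root.tag.split("}")[0].strip("{")} if "}" in toy_root.tag else {}

toy_map = {}
for w in toy_root.findall(".//cwe:Weakness" if ns else ".//Weakness", ns):
    d = w.find("cwe:Description", ns) if ns else w.find("Description")
    text = (d.text or "").strip() if d is not None else ""
    toy_map[f"CWE-{w.get('ID')}"] = {
        "name": w.get("Name"),
        "description": text or "No description available.",
    }

print(toy_map)
```

Deriving the namespace from the root tag avoids hard-coding a schema URI, which may change between catalog releases.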


## 2. Build retriever (TF‑IDF fallback + optional SBERT)

- Default: TF‑IDF + cosine similarity (offline, fast, no downloads)
- Optional: SBERT (`sentence-transformers/all-mpnet-base-v2`); it replaces TF‑IDF automatically when `sentence-transformers` is installed and the model loads
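
As a rough illustration of the TF‑IDF path, the sketch below ranks three hypothetical documents against a query using smoothed IDF, `ln((1 + N) / (1 + df)) + 1`, and L2-normalized vectors, so the dot product equals cosine similarity. It uses dense NumPy arrays for brevity; all `toy_*` names and the documents are assumptions for illustration.

```python
import math

import numpy as np

# Hypothetical mini-corpus standing in for CWE texts
toy_docs = [
    "sql injection in query string",
    "buffer overflow in stack memory",
    "cross site scripting in web page",
]
toy_tokens = [d.split() for d in toy_docs]
toy_terms = sorted({t for doc in toy_tokens for t in doc})
toy_vocab = {t: i for i, t in enumerate(toy_terms)}

n_docs = len(toy_docs)
toy_df = {t: sum(t in doc for doc in toy_tokens) for t in toy_terms}
# Smoothed IDF: terms present in every document still get a weight of 1.0
toy_idf = np.array([math.log((1 + n_docs) / (1 + toy_df[t])) + 1 for t in toy_terms])


def toy_embed(tokens):
    """Term counts times IDF, scaled to unit length (zero vector stays zero)."""
    v = np.zeros(len(toy_vocab))
    for t in tokens:
        if t in toy_vocab:  # out-of-vocabulary tokens are ignored
            v[toy_vocab[t]] += 1.0
    v *= toy_idf
    norm = np.linalg.norm(v)
    return v / norm if norm else v


toy_X = np.stack([toy_embed(doc) for doc in toy_tokens])
toy_q = toy_embed("stack buffer overflow".split())
toy_sims = toy_X @ toy_q  # cosine similarity, since rows and query are unit vectors
toy_best = int(np.argmax(toy_sims))
print(toy_docs[toy_best], toy_sims.round(3))
```

Lexical matching of this kind only scores documents that share tokens with the query, which is one reason an embedding backend can help when CVE wording differs from CWE wording.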

In [4]:
def _normalize_text(s: str) -> str:
    s = (s or "").lower()
    s = re.sub(r"[^a-z0-9]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _tokenize(s: str):
    return [t for t in _normalize_text(s).split(" ") if t]


def build_tfidf_index(texts: list[str]):
    n_docs = len(texts)
    doc_term_counts = []
    df = {}

    for text in texts:
        counts = {}
        for tok in _tokenize(text):
            counts[tok] = counts.get(tok, 0) + 1
        doc_term_counts.append(counts)
        for tok in counts.keys():
            df[tok] = df.get(tok, 0) + 1

    vocab = {tok: i for i, tok in enumerate(sorted(df.keys()))}
    n_terms = len(vocab)

    rows, cols, vals = [], [], []
    for r, counts in enumerate(doc_term_counts):
        for tok, tf in counts.items():
            c = vocab[tok]
            rows.append(r)
            cols.append(c)
            vals.append(float(tf))

    tf = sp.csr_matrix((vals, (rows, cols)), shape=(n_docs, n_terms), dtype=np.float32)

    idf_vec = np.empty(n_terms, dtype=np.float32)
    for tok, c in vocab.items():
        idf_vec[c] = math.log((1.0 + n_docs) / (1.0 + df[tok])) + 1.0

    X = tf.multiply(idf_vec)

    row_norm = np.sqrt(X.multiply(X).sum(axis=1)).A1
    row_norm[row_norm == 0] = 1.0
    X = sp.diags(1.0 / row_norm).dot(X)

    return X, vocab, idf_vec


def tfidf_query(text: str, vocab, idf_vec):
    counts = {}
    for tok in _tokenize(text):
        if tok in vocab:
            counts[tok] = counts.get(tok, 0) + 1

    if not counts:
        return sp.csr_matrix((1, len(vocab)), dtype=np.float32)

    rows, cols, vals = [], [], []
    for tok, tf in counts.items():
        rows.append(0)
        cols.append(vocab[tok])
        vals.append(float(tf))

    q_tf = sp.csr_matrix((vals, (rows, cols)), shape=(1, len(vocab)), dtype=np.float32)
    q = q_tf.multiply(idf_vec)

    q_norm = np.sqrt(q.multiply(q).sum(axis=1)).A1
    q_norm[q_norm == 0] = 1.0
    q = q.multiply(1.0 / q_norm[0])

    return q


_cwe_texts = [c["text"] for c in cwe_corpus]
X_tfidf, vocab, idf_vec = build_tfidf_index(_cwe_texts)

RETRIEVER_BACKEND = "tfidf"


def retrieve_topk(query_text: str, k: int = 5):
    q = tfidf_query(query_text, vocab, idf_vec)
    sims = (X_tfidf @ q.T).toarray().ravel()
    top_idx = np.argsort(-sims)[:k]
    return [(int(i), float(sims[int(i)])) for i in top_idx]


# Optional SBERT backend
try:
    from sentence_transformers import SentenceTransformer

    _sbert = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    _cwe_emb = _sbert.encode(_cwe_texts, normalize_embeddings=True, show_progress_bar=True)
    RETRIEVER_BACKEND = "sbert"

    def retrieve_topk(query_text: str, k: int = 5):
        q = _sbert.encode([query_text], normalize_embeddings=True)[0]
        sims = _cwe_emb @ q
        top_idx = np.argsort(-sims)[:k]
        return [(int(i), float(sims[int(i)])) for i in top_idx]

except Exception:
    # SBERT unavailable (package missing or model failed to load): keep the TF-IDF retriever.
    pass

print("Retriever backend:", RETRIEVER_BACKEND)

Batches: 100%|██████████| 31/31 [00:01<00:00, 17.94it/s]

Retriever backend: sbert





## 3. Load CVEs with explicit CWE labels

We iterate all `CVE-*.json` under the local CVE tree and keep only those with:
- at least one explicit `cweId` that looks like `CWE-<digits>`
- a non-empty description (the English entry is preferred; otherwise the first non-empty description in any language is used)
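
For orientation, the sketch below applies the same two filters to a hypothetical record trimmed down to the fields they read. The record contents and the names `sample_cve`, `labels`, and `english` are assumptions for illustration.

```python
import re

CWE_ID_PATTERN = re.compile(r"^CWE-\d+$")

# Hypothetical record, reduced to the fields the filter inspects
sample_cve = {
    "cveMetadata": {"cveId": "CVE-0000-0001"},
    "containers": {"cna": {
        "descriptions": [
            {"lang": "en", "value": "Toy product allows SQL injection via the id parameter."},
        ],
        "problemTypes": [{"descriptions": [
            {"lang": "en", "cweId": "CWE-89", "description": "CWE-89 SQL Injection", "type": "CWE"},
            {"lang": "en", "description": "Free-text problem type without a cweId"},
        ]}],
    }},
}

cna = sample_cve.get("containers", {}).get("cna", {})

# Keep only well-formed ids; free-text problem types have no cweId and are skipped
labels = sorted({
    d["cweId"]
    for pt in cna.get("problemTypes", [])
    for d in pt.get("descriptions", [])
    if isinstance(d.get("cweId"), str) and CWE_ID_PATTERN.match(d["cweId"])
})

# First non-empty English description, or an empty string if there is none
english = next(
    (d["value"].strip() for d in cna.get("descriptions", [])
     if d.get("lang") == "en" and d.get("value")),
    "",
)
print(labels, "|", english)
```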

In [5]:
CWE_ID_RE = re.compile(r"^CWE-\d+$")


def extract_explicit_cwes(cve_data: dict) -> list[str]:
    cwe_ids = []
    problem_types = cve_data.get("containers", {}).get("cna", {}).get("problemTypes", [])
    for pt in problem_types:
        for desc in pt.get("descriptions", []):
            cwe_id = desc.get("cweId")
            if isinstance(cwe_id, str) and CWE_ID_RE.match(cwe_id):
                cwe_ids.append(cwe_id)
    return sorted(set(cwe_ids))


def extract_cve_description(cve_data: dict) -> str:
    descs = cve_data.get("containers", {}).get("cna", {}).get("descriptions", [])
    for d in descs:
        if d.get("lang") == "en" and d.get("value"):
            return str(d.get("value")).strip()
    for d in descs:
        if d.get("value"):
            return str(d.get("value")).strip()
    return ""


def iter_labeled_cves(cve_root: Path):
    """Yield dicts: {cve_id, description, true_cwes} for CVEs with explicit CWE labels."""
    for p in cve_root.rglob("CVE-*.json"):
        try:
            data = json.loads(p.read_text(encoding="utf-8"))
        except Exception:
            continue

        cve_id = data.get("cveMetadata", {}).get("cveId") or p.stem
        cve_id = str(cve_id)

        true_cwes = extract_explicit_cwes(data)
        if not true_cwes:
            continue

        desc = extract_cve_description(data)
        if not desc:
            continue

        yield {"cve_id": cve_id, "description": desc, "true_cwes": true_cwes}


# Quick count (can take a bit; you can skip this cell if slow)
# n = sum(1 for _ in iter_labeled_cves(CVE_ROOT))
# print("Labeled CVEs:", n)
print("Iterator ready.")

Iterator ready.


## 4. Evaluation

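We compute:
- **top‑1**: whether the highest‑ranked retrieved CWE matches any ground truth
- **top‑k**: whether any of the retrieved top‑k CWEs matches any ground truth
- **llm_top1** (optional): whether the single CWE the LLM selects from the top‑k candidates matches any ground truth

Controls:
- `MAX_RECORDS`: set `None` to evaluate all labeled CVEs (can be slow)
- `SHUFFLE`: randomize order (seeded by `SEED`) before taking `MAX_RECORDS`
- `USE_LLM`: let an Ollama model pick one CWE from the retrieved candidates
- `USE_ABSTRACTION`: have the LLM rewrite the CVE description as a generic weakness pattern before retrieval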
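
Shuffling and then truncating requires every labeled record to be in memory first. As a possible alternative (not what this notebook does), a seeded reservoir sample draws a uniform subset in one pass over an iterator. The helper `reservoir_sample` and the toy stream are hypothetical.

```python
import random
from typing import Iterable, TypeVar

T = TypeVar("T")


def reservoir_sample(items: Iterable[T], k: int, seed: int = 42) -> list[T]:
    """Uniformly sample up to k items from a stream in one pass (Algorithm R)."""
    rng = random.Random(seed)
    sample: list[T] = []
    for i, item in enumerate(items):
        if i < k:
            sample.append(item)    # fill the reservoir first
        else:
            j = rng.randint(0, i)  # inclusive on both ends
            if j < k:
                sample[j] = item   # replace with probability k / (i + 1)
    return sample


# Hypothetical stand-in for a stream of labeled CVE records
stream = ({"cve_id": f"CVE-0000-{i:04d}"} for i in range(1000))
subset = reservoir_sample(stream, k=5, seed=42)
print([r["cve_id"] for r in subset])
```

A fixed seed keeps the sample reproducible across runs as long as the iteration order of the input is stable; directory traversal order is not guaranteed to be, so sorting paths first may be needed.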

In [6]:
TOP_K = 10
MAX_RECORDS = 100   # requested: keep it quick; set None for all labeled CVEs
SHUFFLE = True
SEED = 42

# LLM selector (Ollama)
USE_LLM = True
OLLAMA_MODEL = "mistral:7b-instruct"
OLLAMA_TIMEOUT_S = 180

# NEW Phase 2: Query Abstraction
USE_ABSTRACTION = True  # Enable LLM-based query abstraction before RAG
ABSTRACTION_TIMEOUT_S = 60

records = list(iter_labeled_cves(CVE_ROOT))
print("Total labeled CVEs found:", len(records))

if SHUFFLE:
    random.Random(SEED).shuffle(records)

if MAX_RECORDS is not None:
    records = records[:MAX_RECORDS]

print("Evaluating:", len(records), "records")


def ollama_available() -> bool:
    return shutil.which("ollama") is not None


def build_llm_prompt(cve_description: str, candidates: list[dict]) -> str:
    lines = []
    lines.append("You are a security analyst.")
    lines.append("Return ONLY valid JSON. Do not wrap it in markdown.")
    lines.append("")
    lines.append("VULNERABILITY DESCRIPTION:")
    lines.append((cve_description or "").strip())
    lines.append("")
    lines.append("CANDIDATE CWE DEFINITIONS:")

    for i, c in enumerate(candidates, start=1):
        lines.append("")
        lines.append(f"{i}. {c['cwe_id']} — {c.get('name','')}")
        lines.append(f"Definition: {c.get('description','')}")

    lines.append("")
    lines.append(
        "Task: Choose the SINGLE best CWE from the candidates. "
        "If none fit well, output best_cwe as 'NONE'. "
        "Respond in JSON with keys: best_cwe, confidence (0..1), rationale."
    )

    return "\n".join(lines)


def run_ollama(prompt: str, model: str, timeout_s: int):
    proc = subprocess.run(
        ["ollama", "run", model],
        input=prompt.encode("utf-8"),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=timeout_s,
    )
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr.decode("utf-8", errors="ignore"))
    return proc.stdout.decode("utf-8", errors="ignore").strip()


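def parse_llm_json(text: str) -> dict | None:
    if not text:
        return None

    # Decode the first JSON object in the output. raw_decode stops at the end of
    # that object, so trailing text or a second object does not break parsing
    # (a greedy regex from the first "{" to the last "}" would).
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            obj = None
        if isinstance(obj, dict):
            return obj
        start = text.find("{", start + 1)
    return None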


def normalize_cwe(s: str | None) -> str | None:
    if not s:
        return None
    s = str(s).strip().upper()
    if s == "NONE":
        return "NONE"

    # Accept formats like "CWE-79" or "79"
    if re.fullmatch(r"CWE-\d+", s):
        return s
    if re.fullmatch(r"\d+", s):
        return f"CWE-{s}"

    # last resort: search inside
    m = re.search(r"CWE-\d+", s)
    return m.group(0) if m else None


def build_abstraction_prompt(cve_description: str) -> str:
    """Build prompt to convert specific CVE description into abstract security pattern."""
    lines = []
    lines.append("You are a security analyst expert in vulnerability classification.")
    lines.append("Return ONLY the abstracted description text. Do not add explanations or markdown.")
    lines.append("")
    lines.append("TASK: Convert this SPECIFIC vulnerability description into an ABSTRACT security weakness pattern.")
    lines.append("")
    lines.append("SPECIFIC VULNERABILITY:")
    lines.append(cve_description.strip() if cve_description else "(missing)")
    lines.append("")
    lines.append("INSTRUCTIONS:")
    lines.append("- Remove product names, version numbers, and implementation details")
    lines.append("- Focus on the UNDERLYING weakness type (e.g., 'buffer overflow' → 'out-of-bounds write')")
    lines.append("- Use generic security terminology (not vendor-specific jargon)")
    lines.append("- Keep it concise (1-3 sentences)")
    lines.append("")
    lines.append("ABSTRACT WEAKNESS PATTERN:")
    return "\n".join(lines)


def abstract_cve_description(cve_description: str, model: str, timeout_s: int = 60):
    """Use Ollama to convert specific CVE description into abstract security pattern.
    
    Returns: abstracted_text or None if fails
    """
    if not ollama_available():
        return None
    
    prompt = build_abstraction_prompt(cve_description)
    
    try:
        proc = subprocess.run(
            ["ollama", "run", model],
            input=prompt.encode("utf-8"),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            timeout=timeout_s,
        )
        if proc.returncode != 0:
            return None
        
        raw_output = proc.stdout.decode("utf-8", errors="ignore").strip()
        
        # Clean up: remove any markdown formatting or extra explanations
        abstracted = raw_output.strip()
        
        # If LLM added markdown code blocks, extract the content
        if abstracted.startswith("```"):
            lines = abstracted.split("\n")
            abstracted = "\n".join([l for l in lines if not l.strip().startswith("```")])
        
        abstracted = abstracted.strip()
        
        return abstracted if abstracted else None
        
    except Exception:
        return None


def predict_topk_from_description(desc: str, k: int = 5):
    idx_scores = retrieve_topk(desc, k=k)
    preds = []
    for idx, score in idx_scores:
        c = cwe_corpus[idx]
        preds.append({
            "cwe_id": c["id"],
            "score": float(score),
            "name": c["name"],
            "description": c["description"],
        })
    return preds


def llm_choose_best_cwe(desc: str, candidates: list[dict]):
    """Return chosen CWE id (CWE-123 / NONE) and raw model output."""
    prompt = build_llm_prompt(desc, candidates)
    out = run_ollama(prompt, model=OLLAMA_MODEL, timeout_s=OLLAMA_TIMEOUT_S)
    parsed = parse_llm_json(out)
    best = normalize_cwe(parsed.get("best_cwe") if parsed else None)
    return best, out


def evaluate(records, k: int = 5, progress_every: int = 20):
    t0 = time.time()

    n = 0
    top1_correct = 0
    topk_correct = 0

    llm_used = 0
    llm_correct = 0
    llm_none = 0
    llm_fail = 0
    
    # NEW: abstraction tracking
    abstraction_used = 0
    abstraction_fail = 0

    # optional: keep some mistakes for inspection
    mistakes = []

    if USE_LLM and not ollama_available():
        print("WARNING: USE_LLM=True but 'ollama' was not found on PATH; disabling LLM evaluation.")
    
    if USE_ABSTRACTION and not ollama_available():
        print("WARNING: USE_ABSTRACTION=True but 'ollama' was not found on PATH; disabling abstraction.")

    use_llm = USE_LLM and ollama_available()
    use_abstraction = USE_ABSTRACTION and ollama_available()

    for r in records:
        n += 1
        true_set = set(r["true_cwes"])
        
        # NEW: Apply abstraction before retrieval if enabled
        query_for_retrieval = r["description"]
        if use_abstraction:
            abstracted = abstract_cve_description(r["description"], model=OLLAMA_MODEL, timeout_s=ABSTRACTION_TIMEOUT_S)
            if abstracted:
                query_for_retrieval = abstracted
                abstraction_used += 1
            else:
                abstraction_fail += 1
                # Fallback to original description
        
        preds = predict_topk_from_description(query_for_retrieval, k=k)

        pred_ids = [p["cwe_id"] for p in preds]
        top1 = pred_ids[0] if pred_ids else None

        if top1 in true_set:
            top1_correct += 1

        if any(pid in true_set for pid in pred_ids):
            topk_correct += 1
        else:
            if len(mistakes) < 20:
                mistakes.append({
                    "cve_id": r["cve_id"],
                    "true": r["true_cwes"],
                    "pred_topk": pred_ids,
                    "desc": r["description"][:300],
                })

        # LLM selector accuracy (choose 1 from the top-k candidates)
        if use_llm and preds:
            try:
                best_cwe, raw = llm_choose_best_cwe(r["description"], preds)
                # Count only calls that returned, so llm_top1_acc excludes failures
                # and differs from llm_top1_acc_overall when some calls fail.
                llm_used += 1
                if best_cwe == "NONE":
                    llm_none += 1
                elif best_cwe in true_set:
                    llm_correct += 1
            except Exception:
                llm_fail += 1

        if progress_every and n % progress_every == 0:
            elapsed = time.time() - t0
            msg = f"[{n}/{len(records)}] top1={top1_correct/n:.3f} top{k}={topk_correct/n:.3f}"
            if use_abstraction:
                msg += f" abstraction={abstraction_used}/{n} (fail={abstraction_fail})"
            if use_llm:
                llm_acc = (llm_correct / llm_used) if llm_used else 0.0
                msg += f" llm_top1={llm_acc:.3f} (used={llm_used}, none={llm_none}, fail={llm_fail})"
            msg += f" elapsed={elapsed:.1f}s"
            print(msg)

    elapsed = time.time() - t0

    out = {
        "n": n,
        "top1_acc": top1_correct / n if n else 0.0,
        f"top{k}_acc": topk_correct / n if n else 0.0,
        "elapsed_s": elapsed,
        "backend": RETRIEVER_BACKEND,
        "mistakes_sample": mistakes,
    }
    
    if use_abstraction:
        out.update({
            "abstraction_enabled": True,
            "abstraction_used": abstraction_used,
            "abstraction_fail": abstraction_fail,
        })

    if use_llm:
        out.update({
            "llm_model": OLLAMA_MODEL,
            "llm_used": llm_used,
            "llm_top1_acc": (llm_correct / llm_used) if llm_used else 0.0,
            "llm_top1_acc_overall": (llm_correct / n) if n else 0.0,
            "llm_none": llm_none,
            "llm_fail": llm_fail,
        })

    return out


metrics = evaluate(records, k=TOP_K)
metrics

Total labeled CVEs found: 101404
Evaluating: 100 records
[20/100] top1=0.200 top10=0.700 abstraction=20/20 (fail=0) llm_top1=0.600 (used=20, none=0, fail=0) elapsed=64.8s
[40/100] top1=0.175 top10=0.575 abstraction=40/40 (fail=0) llm_top1=0.450 (used=40, none=0, fail=0) elapsed=130.8s
[60/100] top1=0.183 top10=0.550 abstraction=60/60 (fail=0) llm_top1=0.417 (used=60, none=0, fail=0) elapsed=196.6s
[80/100] top1=0.250 top10=0.575 abstraction=80/80 (fail=0) llm_top1=0.450 (used=80, none=0, fail=0) elapsed=261.5s
[100/100] top1=0.210 top10=0.580 abstraction=100/100 (fail=0) llm_top1=0.440 (used=100, none=0, fail=0) elapsed=324.8s


{'n': 100,
 'top1_acc': 0.21,
 'top10_acc': 0.58,
 'elapsed_s': 324.7723343372345,
 'backend': 'sbert',
 'mistakes_sample': [{'cve_id': 'CVE-2024-4748',
   'true': ['CWE-78'],
   'pred_topk': ['CWE-535',
    'CWE-84',
    'CWE-114',
    'CWE-553',
    'CWE-472',
    'CWE-537',
    'CWE-11',
    'CWE-531',
    'CWE-536',
    'CWE-81'],
   'desc': 'The CRUDDIY project is vulnerable to shell command injection via sending a crafted POST request to the application server.\xa0\nThe exploitation risk is limited since CRUDDIY is meant to be launched locally. Nevertheless, a user with the project running on their computer might visit a website which woul'},
  {'cve_id': 'CVE-2025-30151',
   'true': ['CWE-20'],
   'pred_topk': ['CWE-787',
    'CWE-130',
    'CWE-1284',
    'CWE-645',
    'CWE-1320',
    'CWE-179',
    'CWE-805',
    'CWE-648',
    'CWE-307',
    'CWE-488'],
   'desc': "Shopware is an open commerce platform. It's possible to pass long passwords that leads to Denial Of Service via