# Evaluate CVE→CWE Matching Accuracy (Hybrid RAG)

Goal: evaluate how well our system predicts CWE(s) **when the CVE record already contains explicit CWE labels**.

Method:
- Use all CVEs with at least one explicit `cweId` in `containers.cna.problemTypes`.
- Hide the label(s), predict from the CVE description.
- Accuracy = % where the predicted CWE matches any ground-truth CWE for that CVE.

We report:
- **Top‑1 accuracy** (best prediction)
- **Top‑k accuracy** (ground-truth appears in retrieved top‑k list)

Notes:
- Running over *all* CVEs can take a while; the notebook supports sampling / max limits.

In [1]:
print("Hello World")

Hello World


In [2]:
import json
import re
import math
import time
import random
import shutil
import subprocess
from pathlib import Path

import numpy as np
import scipy.sparse as sp
import xml.etree.ElementTree as ET

PROJECT_ROOT = Path.cwd()
CWE_XML_PATH = PROJECT_ROOT / "data" / "cwec_v4.19.xml"
CVE_ROOT = PROJECT_ROOT / "data" / "cvelistV5-main" / "cves"

print("Project root:", PROJECT_ROOT)
print("CWE XML:", CWE_XML_PATH)
print("CVE root:", CVE_ROOT)
print("Exists?", CWE_XML_PATH.exists(), CVE_ROOT.exists())

Project root: /home/dnfy/Desktop/Fortiss
CWE XML: /home/dnfy/Desktop/Fortiss/data/cwec_v4.19.xml
CVE root: /home/dnfy/Desktop/Fortiss/data/cvelistV5-main/cves
Exists? True True


## 1. Load CWE definitions (ground truth catalog)

We build:
- `cwe_map`: `CWE-<id>` → `{name, description}`
- `cwe_corpus`: list of dense text strings used by the retriever

In [3]:
def parse_cwe_database(xml_path: Path):
    """PHASE 1: Enriched CWE parsing with extended descriptions and consequences"""
    cwe_map = {}
    cwe_corpus = []

    tree = ET.parse(xml_path)
    root = tree.getroot()

    ns = {"cwe": root.tag.split("}")[0].strip("{")} if "}" in root.tag else {}
    xpath = ".//cwe:Weakness" if ns else ".//Weakness"

    for weakness in root.findall(xpath, ns):
        wid = weakness.get("ID")
        wname = weakness.get("Name")

        # Primary description
        desc_elem = weakness.find("cwe:Description", ns) if ns else weakness.find("Description")
        description = (desc_elem.text or "").strip() if desc_elem is not None else ""
        if not description:
            description = "No description available."

        # Extended description (more detailed)
        ext_desc_elem = weakness.find("cwe:Extended_Description", ns) if ns else weakness.find("Extended_Description")
        extended_description = ""
        if ext_desc_elem is not None:
            parts = []
            if ext_desc_elem.text:
                parts.append(ext_desc_elem.text.strip())
            for child in ext_desc_elem:
                if child.text:
                    parts.append(child.text.strip())
                if child.tail:
                    parts.append(child.tail.strip())
            extended_description = " ".join(p for p in parts if p)

        # Common consequences
        consequences = []
        cons_elem = weakness.find("cwe:Common_Consequences", ns) if ns else weakness.find("Common_Consequences")
        if cons_elem is not None:
            for cons in cons_elem.findall("cwe:Consequence" if ns else "Consequence", ns):
                scope_elem = cons.find("cwe:Scope" if ns else "Scope", ns)
                impact_elem = cons.find("cwe:Impact" if ns else "Impact", ns)
                if scope_elem is not None and scope_elem.text:
                    consequences.append(scope_elem.text.strip())
                if impact_elem is not None and impact_elem.text:
                    consequences.append(impact_elem.text.strip())

        cwe_id = f"CWE-{wid}"
        cwe_map[cwe_id] = {
            "name": wname,
            "description": description,
            "extended_description": extended_description
        }

        # Enriched text for retrieval
        text_parts = [f"{cwe_id}: {wname}", description]
        if extended_description:
            text_parts.append(extended_description)
        if consequences:
            text_parts.append("Consequences: " + ", ".join(set(consequences)))

        text = ". ".join(text_parts)
        cwe_corpus.append({
            "id": cwe_id,
            "name": wname,
            "description": description,
            "extended_description": extended_description,
            "text": text
        })

    return cwe_map, cwe_corpus


cwe_map, cwe_corpus = parse_cwe_database(CWE_XML_PATH)
print("CWEs:", len(cwe_map), "(enriched with extended fields)")

CWEs: 969 (enriched with extended fields)


## 2. PHASE 1: Enhanced Hybrid Retrieval

**Improvements:**
1. **BM25 + Dense Hybrid** with RRF fusion
2. **Query Expansion** with security vocabulary
3. **Enriched CWE Documents** with extended descriptions
4. **Two-Stage Retrieval** architecture

In [4]:
# ===========================
# PHASE 1: Query Expansion
# ===========================

SECURITY_EXPANSIONS = {
    # Memory corruption
    "buffer overflow": ["out of bounds write", "memory corruption", "buffer error", "heap overflow", "stack overflow"],
    "heap overflow": ["buffer overflow", "memory corruption", "out of bounds write"],
    "stack overflow": ["buffer overflow", "memory corruption", "out of bounds write"],
    "use after free": ["memory corruption", "dangling pointer", "temporal memory safety"],
    "double free": ["memory corruption", "heap corruption"],
    "null pointer": ["null dereference", "null pointer dereference"],
    
    # Injection attacks
    "sql injection": ["query injection", "improper neutralization", "code injection", "command injection"],
    "command injection": ["os command injection", "code injection", "improper neutralization"],
    "xss": ["cross site scripting", "script injection", "improper neutralization"],
    "cross site scripting": ["xss", "script injection", "improper neutralization"],
    "ldap injection": ["query injection", "improper neutralization"],
    "xpath injection": ["query injection", "improper neutralization"],
    
    # Authentication & Authorization
    "authentication bypass": ["improper authentication", "missing authentication", "broken authentication"],
    "privilege escalation": ["improper privilege management", "improper authorization", "vertical privilege escalation"],
    "authorization": ["access control", "improper authorization", "privilege management"],
    "access control": ["authorization", "improper access control", "missing access control"],
    
    # Cryptographic issues
    "weak encryption": ["inadequate encryption strength", "weak cryptography", "insecure cryptographic algorithm"],
    "weak hash": ["weak cryptographic hash", "inadequate encryption strength"],
    "hard coded": ["hardcoded credentials", "embedded credentials", "plaintext storage"],
    "plaintext": ["cleartext storage", "cleartext transmission", "missing encryption"],
    
    # Path/file issues
    "path traversal": ["directory traversal", "path injection", "improper limitation of pathname"],
    "directory traversal": ["path traversal", "path injection"],
    "file upload": ["unrestricted file upload", "improper validation of file"],
    
    # Input validation
    "integer overflow": ["numeric overflow", "wrap around", "integer wraparound"],
    "format string": ["format string vulnerability", "uncontrolled format string"],
    "race condition": ["time of check time of use", "toctou", "concurrent execution"],
    
    # Web-specific
    "csrf": ["cross site request forgery", "session riding"],
    "cross site request forgery": ["csrf", "session riding"],
    "open redirect": ["url redirection", "improper url redirect"],
    "ssrf": ["server side request forgery", "improper url handling"],
    
    # Misc
    "denial of service": ["resource exhaustion", "uncontrolled resource consumption", "dos"],
    "dos": ["denial of service", "resource exhaustion"],
    "information disclosure": ["information exposure", "sensitive information exposure"],
    "deserialization": ["unsafe deserialization", "untrusted deserialization"],
}


def expand_query(query_text: str, max_expansions: int = 3) -> str:
    """Expand query with security-specific synonyms"""
    query_lower = query_text.lower()
    expanded_terms = []
    
    for trigger, expansions in SECURITY_EXPANSIONS.items():
        if trigger in query_lower:
            expanded_terms.extend(expansions[:max_expansions])
    
    if expanded_terms:
        expanded_terms = list(dict.fromkeys(expanded_terms))
        return query_text + " " + " ".join(expanded_terms)
    
    return query_text


# ===========================
# PHASE 1: BM25 + Dense Retrieval
# ===========================

def _normalize_text(s: str) -> str:
    s = (s or "").lower()
    s = re.sub(r"[^a-z0-9]+", " ", s)
    s = re.sub(r"\s+", " ", s).strip()
    return s


def _tokenize(s: str):
    return [t for t in _normalize_text(s).split(" ") if t]


_cwe_texts = [c["text"] for c in cwe_corpus]
_tokenized_corpus = [_tokenize(text) for text in _cwe_texts]

# BM25 indexing
try:
    from rank_bm25 import BM25Okapi
    _bm25 = BM25Okapi(_tokenized_corpus)
    BM25_AVAILABLE = True
    print(f"BM25 index built: {len(_tokenized_corpus)} documents")
except ImportError:
    BM25_AVAILABLE = False
    print("rank-bm25 not installed; BM25 disabled")


def retrieve_bm25(query_text: str, top_k: int = 20):
    if not BM25_AVAILABLE:
        return []
    query_tokens = _tokenize(query_text)
    if not query_tokens:
        return []
    scores = _bm25.get_scores(query_tokens)
    top_idx = np.argsort(-scores)[:top_k]
    return [(int(i), float(scores[i])) for i in top_idx if scores[i] > 0]


# Dense (SBERT) retrieval
DENSE_AVAILABLE = False
_sbert_model = None
_cwe_emb = None

try:
    from sentence_transformers import SentenceTransformer
    _sbert_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
    _cwe_emb = _sbert_model.encode(_cwe_texts, normalize_embeddings=True, show_progress_bar=True)
    DENSE_AVAILABLE = True
    print(f"Dense embeddings built: {_cwe_emb.shape}")
except Exception:
    print("SBERT not available")


def retrieve_dense(query_text: str, top_k: int = 20):
    if not DENSE_AVAILABLE:
        return []
    q = _sbert_model.encode([query_text], normalize_embeddings=True)[0]
    sims = _cwe_emb @ q
    top_idx = np.argsort(-sims)[:top_k]
    return [(int(i), float(sims[i])) for i in top_idx if sims[i] > 0]


# ===========================
# PHASE 1: Reciprocal Rank Fusion
# ===========================

def reciprocal_rank_fusion(rankings: list[list[tuple[int, float]]], k: int = 60):
    """Fuse multiple rankings with RRF"""
    rrf_scores = {}
    for ranking in rankings:
        for rank, (doc_id, _) in enumerate(ranking, start=1):
            if doc_id not in rrf_scores:
                rrf_scores[doc_id] = 0.0
            rrf_scores[doc_id] += 1.0 / (k + rank)
    fused = sorted(rrf_scores.items(), key=lambda x: -x[1])
    return fused


# ===========================
# PHASE 1: Main Retrieval Function
# ===========================

def retrieve_topk(query_text: str, k: int = 5):
    """Hybrid retrieval with query expansion + RRF fusion"""
    # Step 1: Expand query
    expanded_query = expand_query(query_text)
    
    # Step 2: Retrieve from available backends
    rankings = []
    
    if BM25_AVAILABLE:
        bm25_results = retrieve_bm25(expanded_query, top_k=20)
        if bm25_results:
            rankings.append(bm25_results)
    
    if DENSE_AVAILABLE:
        dense_results = retrieve_dense(query_text, top_k=20)
        if dense_results:
            rankings.append(dense_results)
    
    # Step 3: Fusion
    if len(rankings) >= 2:
        fused = reciprocal_rank_fusion(rankings, k=60)
        return fused[:k]
    elif len(rankings) == 1:
        return rankings[0][:k]
    else:
        return []


# Report status
backends = []
if BM25_AVAILABLE:
    backends.append("BM25")
if DENSE_AVAILABLE:
    backends.append("Dense(SBERT)")

if len(backends) >= 2:
    print(f"✓ Hybrid retrieval: {' + '.join(backends)} with RRF")
elif len(backends) == 1:
    print(f"⚠ Single mode: {backends[0]}")
else:
    print("✗ No retrievers available")

BM25 index built: 969 documents


  from .autonotebook import tqdm as notebook_tqdm
Batches: 100%|██████████| 31/31 [00:04<00:00,  7.17it/s]

Dense embeddings built: (969, 768)
✓ Hybrid retrieval: BM25 + Dense(SBERT) with RRF





## 3. Load CVEs with explicit CWE labels

We iterate all `CVE-*.json` under the local CVE tree and keep only those with:
- at least one explicit `cweId` that looks like `CWE-<digits>`
- a non-empty English description

In [5]:
CWE_ID_RE = re.compile(r"^CWE-\d+$")


def extract_explicit_cwes(cve_data: dict) -> list[str]:
    cwe_ids = []
    problem_types = cve_data.get("containers", {}).get("cna", {}).get("problemTypes", [])
    for pt in problem_types:
        for desc in pt.get("descriptions", []):
            cwe_id = desc.get("cweId")
            if isinstance(cwe_id, str) and CWE_ID_RE.match(cwe_id):
                cwe_ids.append(cwe_id)
    return sorted(set(cwe_ids))


def extract_cve_description(cve_data: dict) -> str:
    descs = cve_data.get("containers", {}).get("cna", {}).get("descriptions", [])
    for d in descs:
        if d.get("lang") == "en" and d.get("value"):
            return str(d.get("value")).strip()
    for d in descs:
        if d.get("value"):
            return str(d.get("value")).strip()
    return ""


def iter_labeled_cves(cve_root: Path):
    """Yield dicts: {cve_id, description, true_cwes} for CVEs with explicit CWE labels."""
    for p in cve_root.rglob("CVE-*.json"):
        try:
            data = json.loads(p.read_text(encoding="utf-8"))
        except Exception:
            continue

        cve_id = data.get("cveMetadata", {}).get("cveId") or p.stem
        cve_id = str(cve_id)

        true_cwes = extract_explicit_cwes(data)
        if not true_cwes:
            continue

        desc = extract_cve_description(data)
        if not desc:
            continue

        yield {"cve_id": cve_id, "description": desc, "true_cwes": true_cwes}


# Quick count (can take a bit; you can skip this cell if slow)
# n = sum(1 for _ in iter_labeled_cves(CVE_ROOT))
# print("Labeled CVEs:", n)
print("Iterator ready.")

Iterator ready.


## 4. Evaluation

We compute:
- **top‑1**: whether the highest‑ranked retrieved CWE matches any ground truth
- **top‑k**: whether any of the retrieved top‑k CWEs matches any ground truth

Controls:
- `MAX_RECORDS`: set `None` to evaluate all labeled CVEs (can be slow)
- `SHUFFLE`: randomize order before taking `MAX_RECORDS`

In [6]:
TOP_K = 5
MAX_RECORDS = 100   # requested: keep it quick; set None for all labeled CVEs
SHUFFLE = True
SEED = 42

# LLM selector (Ollama)
USE_LLM = True
OLLAMA_MODEL = "mistral:7b-instruct"
OLLAMA_TIMEOUT_S = 180

records = list(iter_labeled_cves(CVE_ROOT))
print("Total labeled CVEs found:", len(records))

if SHUFFLE:
    random.Random(SEED).shuffle(records)

if MAX_RECORDS is not None:
    records = records[:MAX_RECORDS]

print("Evaluating:", len(records), "records")


def ollama_available() -> bool:
    return shutil.which("ollama") is not None


def build_llm_prompt(cve_description: str, candidates: list[dict]) -> str:
    lines = []
    lines.append("You are a security analyst.")
    lines.append("Return ONLY valid JSON. Do not wrap it in markdown.")
    lines.append("")
    lines.append("VULNERABILITY DESCRIPTION:")
    lines.append((cve_description or "").strip())
    lines.append("")
    lines.append("CANDIDATE CWE DEFINITIONS:")

    for i, c in enumerate(candidates, start=1):
        lines.append("")
        lines.append(f"{i}. {c['cwe_id']} — {c.get('name','')}")
        lines.append(f"Definition: {c.get('description','')}")

    lines.append("")
    lines.append(
        "Task: Choose the SINGLE best CWE from the candidates. "
        "If none fit well, output best_cwe as 'NONE'. "
        "Respond in JSON with keys: best_cwe, confidence (0..1), rationale."
    )

    return "\n".join(lines)


def run_ollama(prompt: str, model: str, timeout_s: int):
    proc = subprocess.run(
        ["ollama", "run", model],
        input=prompt.encode("utf-8"),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
        timeout=timeout_s,
    )
    if proc.returncode != 0:
        raise RuntimeError(proc.stderr.decode("utf-8", errors="ignore"))
    return proc.stdout.decode("utf-8", errors="ignore").strip()


def parse_llm_json(text: str) -> dict | None:
    if not text:
        return None

    # Try to extract the first JSON object from the output
    m = re.search(r"\{[\s\S]*\}", text)
    if not m:
        return None

    blob = m.group(0)
    try:
        return json.loads(blob)
    except Exception:
        return None


def normalize_cwe(s: str | None) -> str | None:
    if not s:
        return None
    s = str(s).strip().upper()
    if s == "NONE":
        return "NONE"

    # Accept formats like "CWE-79" or "79"
    if re.fullmatch(r"CWE-\d+", s):
        return s
    if re.fullmatch(r"\d+", s):
        return f"CWE-{s}"

    # last resort: search inside
    m = re.search(r"CWE-\d+", s)
    return m.group(0) if m else None


def predict_topk_from_description(desc: str, k: int = 5):
    idx_scores = retrieve_topk(desc, k=k)
    preds = []
    for idx, score in idx_scores:
        c = cwe_corpus[idx]
        preds.append({
            "cwe_id": c["id"],
            "score": float(score),
            "name": c["name"],
            "description": c["description"],
        })
    return preds


def llm_choose_best_cwe(desc: str, candidates: list[dict]):
    """Return chosen CWE id (CWE-123 / NONE) and raw model output."""
    prompt = build_llm_prompt(desc, candidates)
    out = run_ollama(prompt, model=OLLAMA_MODEL, timeout_s=OLLAMA_TIMEOUT_S)
    parsed = parse_llm_json(out)
    best = normalize_cwe(parsed.get("best_cwe") if parsed else None)
    return best, out


def get_retriever_backend():
    """Determine which retrieval backend is active"""
    if BM25_AVAILABLE and DENSE_AVAILABLE:
        return "hybrid-bm25+dense"
    elif BM25_AVAILABLE:
        return "bm25-only"
    elif DENSE_AVAILABLE:
        return "dense-only"
    else:
        return "none"


def evaluate(records, k: int = 5, progress_every: int = 20):
    t0 = time.time()

    n = 0
    top1_correct = 0
    topk_correct = 0

    llm_used = 0
    llm_correct = 0
    llm_none = 0
    llm_fail = 0

    # optional: keep some mistakes for inspection
    mistakes = []

    if USE_LLM and not ollama_available():
        print("WARNING: USE_LLM=True but 'ollama' was not found on PATH; disabling LLM evaluation.")

    use_llm = USE_LLM and ollama_available()

    for r in records:
        n += 1
        true_set = set(r["true_cwes"])
        preds = predict_topk_from_description(r["description"], k=k)

        pred_ids = [p["cwe_id"] for p in preds]
        top1 = pred_ids[0] if pred_ids else None

        if top1 in true_set:
            top1_correct += 1

        if any(pid in true_set for pid in pred_ids):
            topk_correct += 1
        else:
            if len(mistakes) < 20:
                mistakes.append({
                    "cve_id": r["cve_id"],
                    "true": r["true_cwes"],
                    "pred_topk": pred_ids,
                    "desc": r["description"][:300],
                })

        # LLM selector accuracy (choose 1 from the top-k candidates)
        if use_llm and preds:
            try:
                llm_used += 1
                best_cwe, raw = llm_choose_best_cwe(r["description"], preds)
                if best_cwe == "NONE":
                    llm_none += 1
                elif best_cwe in true_set:
                    llm_correct += 1
            except Exception:
                llm_fail += 1

        if progress_every and n % progress_every == 0:
            elapsed = time.time() - t0
            msg = f"[{n}/{len(records)}] top1={top1_correct/n:.3f} top{k}={topk_correct/n:.3f}"
            if use_llm:
                llm_acc = (llm_correct / llm_used) if llm_used else 0.0
                msg += f" llm_top1={llm_acc:.3f} (used={llm_used}, none={llm_none}, fail={llm_fail})"
            msg += f" elapsed={elapsed:.1f}s"
            print(msg)

    elapsed = time.time() - t0

    out = {
        "n": n,
        "top1_acc": top1_correct / n if n else 0.0,
        f"top{k}_acc": topk_correct / n if n else 0.0,
        "elapsed_s": elapsed,
        "backend": get_retriever_backend(),
        "mistakes_sample": mistakes,
    }

    if use_llm:
        out.update({
            "llm_model": OLLAMA_MODEL,
            "llm_used": llm_used,
            "llm_top1_acc": (llm_correct / llm_used) if llm_used else 0.0,
            "llm_top1_acc_overall": (llm_correct / n) if n else 0.0,
            "llm_none": llm_none,
            "llm_fail": llm_fail,
        })

    return out


metrics = evaluate(records, k=TOP_K)
metrics

Total labeled CVEs found: 101404
Evaluating: 100 records
[20/100] top1=0.600 top5=0.800 llm_top1=0.650 (used=20, none=0, fail=0) elapsed=42.3s
[40/100] top1=0.450 top5=0.600 llm_top1=0.500 (used=40, none=0, fail=0) elapsed=84.7s
[60/100] top1=0.367 top5=0.517 llm_top1=0.450 (used=60, none=0, fail=0) elapsed=129.5s
[80/100] top1=0.425 top5=0.562 llm_top1=0.463 (used=80, none=0, fail=0) elapsed=170.7s
[100/100] top1=0.440 top5=0.560 llm_top1=0.480 (used=100, none=0, fail=0) elapsed=209.9s


{'n': 100,
 'top1_acc': 0.44,
 'top5_acc': 0.56,
 'elapsed_s': 209.87854194641113,
 'backend': 'hybrid-bm25+dense',
 'mistakes_sample': [{'cve_id': 'CVE-2025-30151',
   'true': ['CWE-20'],
   'pred_topk': ['CWE-645', 'CWE-640', 'CWE-589', 'CWE-522', 'CWE-1328'],
   'desc': "Shopware is an open commerce platform. It's possible to pass long passwords that leads to Denial Of Service via forms in Storefront forms or Store-API. This vulnerability is fixed in 6.6.10.3 or 6.5.8.17. For older versions of 6.4, corresponding security measures are also available via a plugin. For"},
  {'cve_id': 'CVE-2021-27410',
   'true': ['CWE-787'],
   'pred_topk': ['CWE-1328', 'CWE-1262', 'CWE-124', 'CWE-1277', 'CWE-589'],
   'desc': 'The affected product is vulnerable to an out-of-bounds write, which may result in corruption of data or code execution on the Welch Allyn medical device management tools (Welch Allyn Service Tool: versions prior to v1.10, Welch Allyn Connex Device Integration Suite – Network Co