In [1]:
import os
import re
import json
from tqdm import tqdm
from langchain_ollama.llms import OllamaLLM

In [2]:
# System prompt asking the model to extract malware objects from a CTI report and
# return them as JSON. It is prepended verbatim to each report text before the
# model is invoked.
# NOTE(review): the continuation-line indentation and the typo "folowing" are part
# of the literal prompt text; they are left untouched because editing them changes
# what the model receives.
sft_system_message = """You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all malwares referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe malwares.
                    To describe a malware you should provide the fields id, type, name and is_family.
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no malwares are identified return a json with an empty list "objects".
                    Identify all malwares in the folowing CTI report: """

In [3]:
def load_json(path:str, filename:str):
    """Parse the UTF-8 encoded JSON file ``filename`` found in directory ``path``.

    Args:
        path: Directory containing the file.
        filename: Name of the JSON file inside ``path``.

    Returns:
        The decoded JSON document.
    """
    full_path = os.path.join(path, filename)
    with open(full_path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
    
def format_example(example:dict, system_message):
    """Turn one input/output pair into a three-turn chat transcript.

    Args:
        example: Mapping with an ``"input"`` entry (the report text) and an
            ``"output"`` entry (the expected answer, serialised with ``json.dumps``).
        system_message: Instruction text placed in the first turn.

    Returns:
        A list of three ``{"role", "content"}`` dicts: system, user, assistant.
    """
    return [
        # The instruction belongs in a "system" turn; tagging it "assistant"
        # presents the prompt as something the model itself said.
        {"role": "system", "content": system_message},
        {"role": "user", "content": example["input"]},
        {"role": "assistant", "content": json.dumps(example["output"])},
    ]

In [4]:
# Directory that holds the test examples; each file is loaded as one JSON document
# with an "input" and an "output" entry.
test_path = "/mnt/data/openCTI/splitted-io-pairs/test"
inputs = []
outputs = []
# Only files whose name starts with one of these "<type>--" prefixes are loaded.
include_cti_type = ["malware"]

# os.listdir() returns entries in arbitrary order; sorting makes the example order
# (and therefore the order of every downstream list) reproducible across runs.
for file in sorted(os.listdir(test_path)):
    cti_type = file.split("--")[0]
    if cti_type not in include_cti_type:
        continue
    example = load_json(test_path, file)
    inputs.append(example["input"])
    outputs.append(example["output"])

In [5]:
# Ollama-served model used for every prediction in this notebook.
# num_ctx and num_predict are passed straight to OllamaLLM; per the langchain-ollama
# documentation they set the context window size and the cap on generated tokens.
model = OllamaLLM(
    model="gpt-oss:120b",
    num_ctx=128000,
    num_predict=5000,
)

In [7]:
# Query the model once per report, with the system prompt prepended to the report text.
# An explicit accumulating loop (instead of a list comprehension) keeps every
# completed response in `preds` if a later call raises or the cell is interrupted.
preds = []
for user_input in tqdm(inputs):
    preds.append(model.invoke(sft_system_message + user_input))

100%|██████████| 307/307 [57:21<00:00, 11.21s/it] 


In [8]:
# Parsed predictions, one {"objects": [...]} dict per model response.
preds4eval = []
# Lower-cased responses from which no object list could be recovered.
failed_preds = []
# Fallback matcher for a single flat malware object with the keys
# id, type, name, is_family in exactly that order.
pattern = r'\{\s*"id"\s*:\s*"[^"]*"\s*,\s*"type"\s*:\s*"[^"]*"\s*,\s*"name"\s*:\s*"[^"]*"\s*,\s*"is_family"\s*:\s*(?:true|false|null|"(?:[^"]*)"|[-+]?\d+(?:\.\d+)?)\s*\}'
# An explicitly empty "objects" list, tolerant of whitespace differences.
empty_objects_pattern = r'"objects"\s*:\s*\[\s*\]'

for raw_pred in preds:

    # Lower-casing and tab removal are applied to the whole response before parsing.
    pred_text = raw_pred.lower().replace("\t", "")

    # First attempt: the response is a clean JSON document with an "objects" key.
    # Only parsing/shape errors are caught so that interrupts still propagate.
    try:
        objects = json.loads(pred_text)["objects"]
    except (json.JSONDecodeError, KeyError, TypeError):
        objects = None

    # Second attempt: pull individual objects out of the surrounding text.
    # Also taken when "objects" parsed to something that is not a list.
    if not isinstance(objects, list):
        objects = [json.loads(match) for match in re.findall(pattern, pred_text)]
        # Nothing recovered and no explicit empty list => count as a failed output.
        if not objects and not re.search(empty_objects_pattern, pred_text):
            failed_preds.append(pred_text)

    preds4eval.append({"objects": objects})

In [9]:
# Share of model responses from which no object list could be recovered.
# Guarded so an empty test set reports 0.0% instead of raising ZeroDivisionError.
failed_pct = 100 * len(failed_preds) / len(inputs) if inputs else 0.0
print(f"Percentage of failed json outputs: {failed_pct:.1f}%")

Percenrage of failed json outputs: 0.0%


In [10]:
# Post processing
# Collects one {"objects": [...]} dict per prediction, filled by the loop further
# down in this cell after ids and missing fields have been normalised.
processed_preds4eval = []

def fix_malware_id(wrong_id: str) -> str:
    """
    Normalise a possibly malformed malware id to the form ``malware--<rest>``.

    The input is trimmed and lower-cased, stray leading '-'/'_' characters and a
    leading ``malware`` prefix (with the separators after it) are dropped, a
    duplicated ``name--name`` is collapsed to ``name``, and the canonical
    ``malware--`` prefix is put back in front.
    """
    candidate = wrong_id.strip().lower()

    # Strip, in order: leading separators, then a leading "malware" followed by
    # any run of '-' or '_'.
    for junk in (r'^[-_]+', r'^(malware[-_]+)'):
        candidate = re.sub(junk, '', candidate)

    # Exactly two identical halves around a dash run (e.g. 'fatboy--fatboy')
    # collapse to one.
    pieces = re.split(r'--+', candidate)
    if len(pieces) == 2:
        first, second = pieces
        if first == second:
            candidate = first

    # Re-attach the prefix and squeeze the dash run right after 'malware' to two.
    return re.sub(r'^malware-+', 'malware--', "malware--" + candidate)

# Normalise every predicted object to the id/type/name/is_family shape.
for pred in preds4eval:
    objects = []
    for obj in pred["objects"]:
        # Parsed JSON may contain entries that are not objects; they have no
        # fields to normalise.
        if not isinstance(obj, dict):
            continue

        raw_id = obj.get("id")
        name = obj.get("name")
        # Without an id or a name there is nothing to identify the malware by.
        # Previously a missing id raised NameError or silently reused the id of
        # the preceding object.
        if raw_id is None and name is None:
            continue

        # Prefer the predicted id; otherwise derive one from the name.
        # str() guards against non-string values such as numbers.
        malware_id = fix_malware_id(str(raw_id if raw_id is not None else name))

        # Fall back to the id's suffix when the name is absent.
        if name is None:
            name = malware_id.split("malware--")[-1]

        objects.append(
            {
                "id": malware_id,
                "type": "malware",
                "name": name,
                "is_family": obj.get("is_family", False),
            }
        )

    processed_preds4eval.append({"objects": objects})

In [11]:
import warnings
from evaluation.stix_evaluator import STIXEvaluator

# Silence every warning from this point on in the notebook.
warnings.filterwarnings("ignore")

# Evaluator configured with the "type" and "name" fields as comparison values and
# restricted to the "malware" object type.
evaluator = STIXEvaluator(
    comparison_values=["type", "name"],
    cti_object_types=["malware"],
)

In [12]:
# Score the post-processed predictions against the reference outputs and report
# precision, recall and F1 as returned by the evaluator.
p, r, f1, full_res = evaluator._evaluate_(predicted=processed_preds4eval, actual=outputs)
print(f"Precision: {p}\nRecall: {r}\nF1-Score: {f1}")

Precison: 0.27698
Recall: 0.88676
F1-Score: 0.42211
