In [1]:
import os
import re
import yaml
import json
import torch
import pickle
from unsloth import FastLanguageModel
from tqdm import tqdm
import pandas as pd

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 10-09 13:19:21 [__init__.py:244] Automatically detected platform cuda.


In [2]:
#sft_model = "/mnt/data/training-outputs/Llama-3.1-8B-Malware-Expert/checkpoint-271"
sft_model = "/mnt/data/training-outputs/Llama-3.1-8B-Malware-Expert-r128-a256/checkpoint-306"

#sft_model = "/home/deleftheriou/cti-model-training/Llama-3.1-8B-Instruct-DPO-Malware-Expert/checkpoint-393"

sft_system_message = """You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all malwares referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe malwares.
                    To describe a malware you should provide the fields id, type, name and is_family.
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no malwares are identified return a json with an empty list "objects".
                    Identify all malwares in the folowing CTI report: """

# sft_system_message_2 = """You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
#                  Your task is to identify all malwares referenced or implied in a CTI report. 
#                  You MUST return a json with a field "objects" being a list of json objects that describe malwares.
#                  To describe a malware you should provide the fields id, type, name and is_family.
#                  Instead of using UUID in the id field, use the rule type--name for generating ids.
#                  For example, an output in which the malware RandomMalware is identified and is not family
#                  of some other malware should be like this:
                 
#                  {
#                      "objects": [
#                          {
#                              "id": "malware--RandomMalware",
#                              "type": "malware",
#                              "name": "RandomMalware",
#                              "is_family": false
#                          }
#                      ]
#                  }
                 
#                  If no malwares are identified return a json with an empty list "objects".
#                  Identify all malwares in the folowing CTI report: """

In [3]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = sft_model,
    fast_inference = False,
    load_in_4bit = False,
    max_seq_length = None,
    gpu_memory_utilization = 0.8
)

==((====))==  Unsloth 2025.6.8: Fast Llama patching. Transformers: 4.53.0. vLLM: 0.9.1.
   \\   /|    NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.179 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unsloth 2025.6.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [4]:
from transformers import TextStreamer

FastLanguageModel.for_inference(model)
text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def format_input_prompt(system_message, user_input):
    formatted_input = [
        {"role": "assistant", "content": system_message},
        {"role": "user", "content": user_input}
    ]
    return formatted_input

def format_validation_example_for_inference(example):
    return example.split("<|start_header_id|>user<|end_header_id|>")[1].split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")[0]

def inference(model, system_message, user_input, max_new_tokens=None, **kwargs):
    input_ids = tokenizer.apply_chat_template(
        format_input_prompt(system_message, user_input),
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    if not max_new_tokens:
        max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
    model.generate(input_ids, streamer = text_streamer, max_new_tokens=max_new_tokens, **kwargs)

def predict(model, system_message, user_input, max_new_tokens=None, **kwargs):
    input_ids = tokenizer.apply_chat_template(
        format_input_prompt(system_message, user_input),
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    if not max_new_tokens:
        max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
    
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, **kwargs)
    result = tokenizer.batch_decode(output_ids)
    processed_result = result[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1].split("<|eot_id|>")[0]
    return processed_result

In [5]:
def load_json(path:str, filename:str):
    with open(os.path.join(path, filename), mode="r", encoding="utf-8") as f:
        return json.load(f)
    
def format_example(example:dict, system_message):
        formatted_example = [
            {"role": "assistant", "content": system_message},
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": json.dumps(example["output"])}
        ]
        return formatted_example

In [6]:
test_path = "/mnt/data/openCTI/splitted-io-pairs/test"
inputs = []
outputs = []
include_cti_type = ["malware"]

for file in os.listdir(test_path):
    cti_type = file.split("--")[0]
    if cti_type not in include_cti_type:
        continue
    example = load_json(test_path, file)
    inputs.append(example["input"])
    outputs.append(example["output"])

In [None]:
print(outputs[3])

In [None]:
system_message = sft_system_message
user_input = inputs[3]
inference(model,
          system_message, 
          user_input, 
          max_new_tokens=500,
          temperature=0.9,
          top_p=0.9,
          repetition_penalty=1.1,
          no_repeat_ngram_size=3,
          do_sample=True)

{"objects": [{"id": "", "type": "malware", "name": "BlueSky", "is_family": false}]}


In [7]:
system_message = sft_system_message
inputs = inputs[:50]
outputs = outputs[:50]

preds = [predict(model,
                 system_message,
                 user_input,
                 max_new_tokens=500,
                 temperature=0.6,
                 top_p=0.2,
                 repetition_penalty=1.1,
                 no_repeat_ngram_size=3,
                 do_sample=True) for user_input in tqdm(inputs)]

  0%|          | 0/50 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
LlamaForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're writing code, see Llama for an example implementation. If you're a user, please report this issue on GitHub.
100%|██████████| 50/50 [01:40<00:00,  2.00s/it]


In [8]:
preds4eval = []
failed_preds = []
pattern = r'\{\s*"id"\s*:\s*"[^"]*"\s*,\s*"type"\s*:\s*"[^"]*"\s*,\s*"name"\s*:\s*"[^"]*"\s*,\s*"is_family"\s*:\s*(?:true|false|null|"(?:[^"]*)"|[-+]?\d+(?:\.\d+)?)\s*\}'

for p in preds:
    try:
        preds4eval.append(
            {
                "objects":json.loads(p.lower())["objects"]
            }
        )
    except:
        objects = re.findall(pattern, p)
        if not objects:
            failed_preds.append(p)
        else:
            valid_objs = [json.loads(obj) for obj in objects]
        preds4eval.append(
            {
                "objects":valid_objs
                }
        )

In [9]:
print(f"Percenrage of failed json outputs: {'{:.1f}'.format(100 * len(failed_preds) / len(inputs))}%")

Percenrage of failed json outputs: 14.0%


In [10]:
# Post processing
processed_preds4eval = []

def fix_malware_id(wrong_id: str) -> str:
    """
    Fixes malformed malware IDs according to the observed patterns.
    """

    # Trim spaces
    s = wrong_id.strip().lower()

    # Remove leading underscores or hyphens
    s = re.sub(r'^[-_]+', '', s)

    # Remove redundant 'malware' if it’s at the start but malformed
    s = re.sub(r'^(malware[-_]+)', '', s)

    # Handle duplicated name parts (e.g., 'fatboy--fatboy')
    parts = re.split(r'--+', s)
    if len(parts) == 2 and parts[0] == parts[1]:
        s = parts[0]

    # Prepend 'malware--'
    corrected = f"malware--{s}"

    # Ensure only one double dash after 'malware'
    corrected = re.sub(r'^malware-+', 'malware--', corrected)

    return corrected

# Step 1
for p in preds4eval:
    objects = []
    for obj in p["objects"]:
        # Step 1
        if "id" in obj.keys():
            ID = fix_malware_id(obj["id"].strip())
            NAME = ID.split("malware--")[-1]

        # if "name" in obj.keys():
        #     NAME = obj["name"]

        if "is_family" in obj.keys():
            IS_FAMILY = obj["is_family"]
        else:
            IS_FAMILY = False

        objects.append(
            {
                "id":ID,
                "type":"malware",
                "name":NAME,
                "is_family":IS_FAMILY
            }
        )

    processed_preds4eval.append(
                {
                    "objects":objects
                }
            )

In [11]:
from evaluation.stix_evaluator import STIXEvaluator

evaluator = STIXEvaluator(comparison_values=["id"], cti_object_types=["malware"])

In [12]:
p, r, f1, full_res = evaluator._evaluate_(predicted=processed_preds4eval, actual=outputs)
print(f"Precison: {p}\nRecall: {r}\nF1-Score: {f1}")

# Temperature 0.2
# Min_p 0.2
# Precison: 0.7
# Recall: 0.4375
# F1-Score: 0.53846

# Temperature 0.2
# Min_p 0.1
# Precison: 0.7
# Recall: 0.4375
# F1-Score: 0.53846

# Temperature 0.7
# Min_p 0.2
# Precison: 0.7
# Recall: 0.4375
# F1-Score: 0.53846

# Temperature 0.1
# Min_p 0.1
# Precison: 0.7
# Recall: 0.4375
# F1-Score: 0.53846


Precison: 0.56757
Recall: 0.32812
F1-Score: 0.41584


Type and Name will be used to compare stix objects!
All cti types will be evaluated!


In [None]:
print(f"Detailed results:\n {full_res}")