In [1]:
import os
import re
import yaml
import json
import torch
import pickle
from unsloth import FastLanguageModel
from tqdm import tqdm
import pandas as pd

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 09-25 07:39:03 [__init__.py:244] Automatically detected platform cuda.


In [2]:
sft_model = "/mnt/data/training-outputs/Llama-3.1-8B-Malware-Expert/checkpoint-129"

sft_system_message = """You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all malwares referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe malwares.
                    To describe a malware you should provide the fields id, type, name, description[Optional], malware_types[Optional], is_family[Optional], aliases[Optional], os_execution_envs[Optional], architecture_execution_envs[Optional], implementation_languages[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no malwares are identified return a json with an empty list "objects".
                    Identify all malwares in the folowing CTI report: """

In [3]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = sft_model,
    fast_inference = False,
    load_in_4bit = False,
    max_seq_length = None,
    gpu_memory_utilization = 0.8
)

==((====))==  Unsloth 2025.6.8: Fast Llama patching. Transformers: 4.53.0. vLLM: 0.9.1.
   \\   /|    NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.179 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Unsloth 2025.6.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [38]:
from transformers import TextStreamer

FastLanguageModel.for_inference(model)
text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def format_input_prompt(system_message, user_input):
    formatted_input = [
        {"role": "assistant", "content": system_message},
        {"role": "user", "content": user_input}
    ]
    return formatted_input

def format_validation_example_for_inference(example):
    return example.split("<|start_header_id|>user<|end_header_id|>")[1].split("<|eot_id|><|start_header_id|>assistant<|end_header_id|>")[0]

def inference(model, system_message, user_input, max_new_tokens=None, **kwargs):
    input_ids = tokenizer.apply_chat_template(
        format_input_prompt(system_message, user_input),
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    if not max_new_tokens:
        max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
    model.generate(input_ids, streamer = text_streamer, max_new_tokens=max_new_tokens, **kwargs)

def predict(model, system_message, user_input, max_new_tokens=None, **kwargs):
    input_ids = tokenizer.apply_chat_template(
        format_input_prompt(system_message, user_input),
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    if not max_new_tokens:
        max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
    
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, **kwargs)
    result = tokenizer.batch_decode(output_ids)
    processed_result = result[0].split("<|start_header_id|>assistant<|end_header_id|>\n\n")[-1].split("<|eot_id|>")[0]
    return processed_result

In [5]:
def load_json(path:str, filename:str):
    with open(os.path.join(path, filename), mode="r", encoding="utf-8") as f:
        return json.load(f)
    
def format_example(example:dict, system_message):
        formatted_example = [
            {"role": "assistant", "content": system_message},
            {"role": "user", "content": example["input"]},
            {"role": "assistant", "content": json.dumps(example["output"])}
        ]
        return formatted_example

In [None]:
test_path = "/mnt/data/openCTI/splitted-io-pairs/test"
inputs = []
outputs = []
include_cti_type = ["malware"]

for file in os.listdir(test_path):
    cti_type = file.split("--")[0]
    if cti_type not in include_cti_type:
        continue
    example = load_json(test_path, file)
    inputs.append(example["input"])
    outputs.append(example["output"])

In [24]:
print(outputs[3])

{'objects': [{'id': 'malware--BlueSky', 'type': 'malware', 'name': 'BlueSky', 'is_family': False}]}


In [25]:
system_message = sft_system_message
user_input = inputs[3]
inference(model,
          system_message, 
          user_input, 
          max_new_tokens=None,
          temperature=0.1,
          top_p=0.6,
          repetition_penalty=1.1,
          no_repeat_ngram_size=3,
          do_sample=True)

{"objects": [{"id": "malware--BlueSky", "type": "maleware", "name": "Bluesky", "is_family": false}]}


In [39]:
system_message = sft_system_message
user_input = inputs[3]

pred = predict(model,
        system_message, 
        user_input, 
        max_new_tokens=None,
        temperature=0.1,
        top_p=0.6,
        repetition_penalty=1.1,
        no_repeat_ngram_size=3,
        do_sample=True)

In [52]:
preds = [predict(model,
                 system_message,
                 user_input,
                 max_new_tokens=None,
                 temperature=0.1,
                 top_p=0.6,
                 repetition_penalty=1.1,
                 no_repeat_ngram_size=3,
                 do_sample=True) for user_input in tqdm(inputs[:200])]

100%|██████████| 200/200 [07:46<00:00,  2.33s/it]


In [54]:
dict_preds = []
for p in preds:
    try:
        dict_preds.append(json.loads(p))
    except:
        print(p)
        dict_preds.append({})

{"objects": [{"id": "malware--Arid gopher", "type": "maleware", "name": "Arid gofer", "is_family": false}, {"id": "_malware__MicropsIA", "Type": "MALWARE", "Name": "Microps IA", "Is_family": False}]}
{"objects": [{"id": "malware--BlackMagic", "type": "maleware", "name": "BlackMagic', 'is_family': false}]}
{"objects": [{"id": "malware--SpyNote", "type": "maleware", "name": "SpyNote', "is_family": false}]}
{"objects": [{"id": "malware--Zloader", "type": " malware", "name": "Zloader', "is_family": false}]}
{"objects": [{"id": "malware--MyKings", "type": "maleware", "name": "MyKins", "is_family": false}, {"id": "-malware-CosmicStrands", "title": "Cosmic Strands", "_type": "-object", "id": "", "type_: "-malwares", "fields": {}, "is_folder": false, "is_template": false}]}
{"objects": [{"id": "malware--Bahamute Spyware", "type": "maleware", 'name': 'Bahamote Spyware', 'is_family': false}]}
{"objects": [{"id": "malware--GootKit", "type": "maleware", "name": "Gootkit", "is_family": false}, {"id

In [26]:
from evaluation.stix_evaluator import STIXEvaluator

evaluator = STIXEvaluator(comparison_values=["id"], cti_object_types=["malware"])

In [None]:
evaluator._evaluate_(predicted=dict_preds, actual=outputs)