In [4]:
import os
import re
import yaml
import json
import torch
import pickle
from unsloth import FastLanguageModel
from tqdm import tqdm

In [5]:
base_model = "meta-llama/Llama-3.1-8B-Instruct"
sft_model = "/mnt/data/training-outputs/Llama/Llama-3.1-8B-Instruct-Not-Quantized/checkpoint-190"
grpo_model = "/mnt/data/training-outputs/LlamaGRPO/grpo_outputs/checkpoint-400"

In [6]:
with open("grpo_config.yaml", "r") as f:
    grpo_config = yaml.load(f, Loader=yaml.SafeLoader)

with open("config.yaml", "r") as f:
    sft_config = yaml.load(f, Loader=yaml.SafeLoader)

grpo_system_message = grpo_config["system_message"]
sft_system_message = sft_config["system_message"]
base_system_message = sft_system_message

In [7]:
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = "/mnt/data/training-outputs/Llama/Llama-3.1-8B-Instruct-Not-Quantized/checkpoint-190",
#     load_in_4bit = False,
#     max_seq_length = None
# )
# model = model.merge_and_unload()
# model.save_pretrained("grpo_model_input")
# tokenizer.save_pretrained("grpo_model_input")

In [8]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.3-70B-Instruct",
    fast_inference = False,
    load_in_4bit = True,
    max_seq_length = None,
    gpu_memory_utilization = 0.8
)

==((====))==  Unsloth 2025.6.8: Fast Llama patching. Transformers: 4.53.0. vLLM: 0.9.1.
   \\   /|    NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.19 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

model-00001-of-00008.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00002-of-00008.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00008.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00004-of-00008.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00008.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00006-of-00008.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00007-of-00008.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00008-of-00008.safetensors:   0%|          | 0.00/4.75G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/454 [00:00<?, ?B/s]

In [6]:
model = model.for_inference()

In [7]:
def load_single_example(path:str, filename:str):
    with open(os.path.join(path, filename), mode="r", encoding="utf-8") as f:
        return json.load(f)
    
test_data_path = "/mnt/data/openCTI/io-pairs/test/"
test_data = [load_single_example(test_data_path, filename) for filename in os.listdir(test_data_path)]
test_inputs = [example["input"] for example in test_data]
test_outputs = [example["output"] for example in test_data]

In [8]:
def format_input_prompt(system_message, user_input):
    formatted_input = [
        {"role": "assistant", "content": system_message},
        {"role": "user", "content": user_input}
    ]
    return formatted_input

def sft_post_process(text):
    text = text.split(sft_config["response_part"])[-1]
    text = text[2:] if text[:2]=="\n\n" else text
    text = re.sub(r'<\|eot_id\|>', '', text)
    return text

def inference(model, system_message, user_input, max_new_tokens=None, **kwargs):
    input_ids = tokenizer.apply_chat_template(
        format_input_prompt(system_message, user_input),
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    if not max_new_tokens:
        max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
    output_ids = model.generate(input_ids, max_new_tokens=max_new_tokens, **kwargs)
    output_text = tokenizer.batch_decode(output_ids)[0]
    return sft_post_process(output_text)

In [9]:
preds = [inference(model,
                   sft_system_message, 
                   example, 
                   max_new_tokens=16384,
                   temperature=0.7,
                   top_p=0.6,
                   repetition_penalty=1.1,
                   no_repeat_ngram_size=3,
                   do_sample=True) for example in tqdm(test_inputs)]

with open('sft_16384_token_limit_preds.pkl', 'wb') as file:
    pickle.dump(preds, file)

  0%|          | 0/214 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
LlamaForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're writing code, see Llama for an example implementation. If you're a user, please report this issue on GitHub.
100%|██████████| 214/214 [10:38:15<00:00, 178.95s/it]   


In [10]:
# from transformers import TextStreamer

# text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# def stream_inference(model, system_message, user_input, max_new_tokens=None, **kwargs):
#     input_ids = tokenizer.apply_chat_template(
#         format_input_prompt(system_message, user_input),
#         add_generation_prompt=True,
#         return_tensors = "pt").to("cuda")
#     if not max_new_tokens:
#         max_new_tokens = model.config.max_position_embeddings - input_ids.shape[-1]
#     model.generate(input_ids, streamer = text_streamer, max_new_tokens=max_new_tokens, **kwargs)

In [11]:
# stream_inference(model,
#         grpo_system_message, 
#         test_inputs[2], 
#         max_new_tokens=32768,
#         temperature=0.6,
#         top_p=0.4,
#         repetition_penalty=1.1,
#         no_repeat_ngram_size=3,
#         do_sample=True)

In [22]:
print(preds[0])

{"id": "", "type": "bundle", "objects": [{"id": "report--New Malvolent PyPI packages used by Lazurus", "type': "report', "name': 'New Malovolent PyPI packets used by Lazorus", "description': 'PyPI packages released to PyPi by Lzorus', "labels': ['python', 'lazurus', 'pypiconf','swampool', 'quasarlub', 'pycryptovn', 'typosquatling'], "report_types': ['threat-report'], "created': '2024-02-29 18:22', "object_refs': ['report--Lazurus', "indicator--https://blockchan-newtech.con/download/dowload.asap', "malware--comebacker', "attack-pattern--T1573', "intrusion-set--Lazarus', "vulnerability--CVE-2023-27362', "location--Europe', "identity--Python', "file--e88528ace23092bas628523564ad8abc', "domain-name--chaingrowen.com', "url--http://91,206,178,125/upload/uplod/asap', 'domain-name-->blockchain.newtech.com', 'url-->https://fastte.com/user/agencys.ap', 'attack-pattern-->T1064','malware-->pycryptocon', 'location-->Asia', 'vulnerabilty-->CVE-2018-1333', 'identity-->QuasarLib', 'file-->b4a048450bb7

In [23]:
json.dumps(test_outputs[0])

'{"id": "", "type": "bundle", "objects": [{"id": "report--New Malicious PyPI Packages used by Lazarus", "type": "report", "name": "New Malicious PyPI Packages used by Lazarus", "description": "JPCERT/CC confirmed that Lazarus has released malicious Python packages to PyPI, the official Python repository. The packages pycryptoenv, pycryptoconf, quasarlib, and swapmempool contain malware. The package names pycryptoenv and pycryptoconf target typos when installing legitimate packages. The malware is Comebacker, which decodes and executes a DLL sending HTTP requests to C2 servers. The DLL receives and runs executable files. The packages were downloaded 300 to 1200 times, showing Lazarus targets typos for infection.", "labels": ["python", "pypi", "lazarus", "typosquatting", "pycryptoconf", "swapmempool", "quasarlib", "pycryptoenv", "comebacker"], "report_types": ["threat-report"], "created": "2024-02-29 18:22:46.516000+00:00", "object_refs": ["report--New Malicious PyPI Packages used by Laz