In [1]:
import os
import re
import json

from multiprocessing import cpu_count
num_proc = cpu_count()

import yaml

from dataprep.stix.StixConfig import StixToPydanticMap, STIX
from pydantic import BaseModel, ValidationError


from evaluation.stix_evaluator import STIXEvaluator

from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
import torch

from trl import GRPOConfig, GRPOTrainer

from data_processor import SplittedJsonIoDataset
from customs import customize_tokenizer

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-04 13:25:18 [__init__.py:244] Automatically detected platform cuda.


In [2]:
# Clear GPU cache
torch.cuda.empty_cache()

In [3]:
with open("grpo_config.yaml", "r") as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

In [4]:
# # Check if model is a peft model
# import peft.helpers

# def load_model_and_tokenizer(model_name_or_path, config, **kwargs):
#     if peft.helpers.check_if_peft_model(model_name_or_path):
#         model, tokenizer = FastLanguageModel.from_pretrained(
#                 model_name=model_name_or_path,
#                 **kwargs
#             )
#         if config["merge_peft_model"]:
#             model.merge_and_unload()
#     else:
#         model, tokenizer = FastLanguageModel.from_pretrained(
#             **config["model_loading_args"]
#         )
#     return model, tokenizer

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="meta-llama/Llama-3.1-8B-Instruct", #deepseek-ai/DeepSeek-R1-0528-Qwen3-8B", #"/mnt/data/training-outputs/first-run/outputs/checkpoint-194",
    fast_inference = True,
    load_in_4bit = True,
    max_seq_length = None,
    gpu_memory_utilization = 0.7
    #65536
)

==((====))==  Unsloth 2025.6.8: Fast Llama patching. Transformers: 4.53.0. vLLM: 0.9.1.
   \\   /|    NVIDIA H100 PCIe. Num GPUs = 1. Max memory: 79.19 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 9.0. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/llama-3.1-8b-instruct-unsloth-bnb-4bit with actual GPU utilization = 69.53%
Unsloth: Your GPU has CUDA compute capability 9.0 with VRAM = 79.19 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 131072. Num Sequences = 368.
Unsloth: vLLM's KV Cache can use up to 48.73 GB. Also swap space = 6 GB.
INFO 07-04 13:25:32 [config.py:823] This model supports multiple tasks: {'reward', 'generate', 'embed', 'classify', 'score'}. Defaulting to 'generate'.
INFO 07-04 13:25:33 [config.py:2195

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 07-04 13:25:38 [punica_selector.py:19] Using PunicaWrapperGPU.
INFO 07-04 13:25:39 [gpu_model_runner.py:1624] Model loading took 5.9525 GiB and 2.490744 seconds
INFO 07-04 13:25:50 [backends.py:462] Using cache directory: /home/deleftheriou/.cache/vllm/torch_compile_cache/7434f386e1/rank_0_0 for vLLM's torch.compile
INFO 07-04 13:25:50 [backends.py:472] Dynamo bytecode transform time: 10.35 s
INFO 07-04 13:25:55 [backends.py:135] Directly load the compiled graph(s) for shape None from the cache, took 4.473 s
INFO 07-04 13:26:04 [monitor.py:34] torch.compile takes 10.35 s in total
INFO 07-04 13:26:05 [gpu_worker.py:227] Available KV cache memory: 32.07 GiB
INFO 07-04 13:26:05 [kv_cache_utils.py:715] GPU KV cache size: 262,704 tokens
INFO 07-04 13:26:05 [kv_cache_utils.py:719] Maximum concurrency for 131,072 tokens per request: 2.00x
INFO 07-04 13:27:08 [gpu_model_runner.py:2048] Graph capturing finished in 63 secs, took 1.60 GiB
INFO 07-04 13:27:08 [core.py:171] init engine (profil

In [5]:
model.max_seq_length

131072

In [6]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 8, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 8,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

Unsloth 2025.6.8 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [7]:
#model, tokenizer = customize_tokenizer(model, tokenizer, config)
# no need for deepseek
# config["chat_template"] = "deepseek"
# tokenizer = get_chat_template(tokenizer, config["chat_template"])

In [8]:
# Create dataset for training
dataset = SplittedJsonIoDataset(tokenizer, config).grpo_create()

Map:   0%|          | 0/1564 [00:00<?, ? examples/s]

Map:   0%|          | 0/207 [00:00<?, ? examples/s]

In [9]:
def deserialize_answer(answer: str) -> dict:
    return json.loads(answer)

def deserialize_response_for_evaluation(answer: str) -> dict:
    if is_stix_bundle(answer):
        return json.loads(answer)
    else:
        return {"id":"", "type":"bundle", "objects":[]}

def extract_xml_answer(response: str) -> str:
    answer = response.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def is_stix_bundle(text: str) -> bool:
    try:
        bundle = json.loads(text)
        pydantic_stix_bundle = STIX(**bundle)
        return True
    except:
        return False

def count_valid_stix_objects(text: str) -> bool:
    smap = StixToPydanticMap()
    cnt = 0.0
    if is_stix_bundle(text):
        bundle = json.loads(text)
        for obj in bundle["objects"]:
            try:
                smap(obj)
                cnt += 1
            except ValidationError:
                pass
        return cnt / len(bundle["objects"])
    else:
      return cnt

In [10]:
def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def stix_validity_reward_func(completions, answers, **kwargs) -> list[float]:
    """Reward function that checks if the completion can is a stix bundle."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if is_stix_bundle(r) else 0.0 for r in extracted_responses]

def stix_objects_validity_reward_func(completions, answers, **kwargs) -> list[float]:
    """Reward function that checks if the completion has valid stix objects."""
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 * count_valid_stix_objects(r) for r in extracted_responses]

def accuracy_reward_func(completions, answers, **kwargs) -> list[float]:
    evaluator = STIXEvaluator()
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    desirialized_responses = [deserialize_response_for_evaluation(r) for r in extracted_responses]
    desirialized_answers = [deserialize_answer(a) for a in answers]
    return [evaluator.evaluate_single(r, a)[2] for r, a in zip(desirialized_responses, desirialized_answers)]

In [11]:
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = 16384, #config["model_loading_args"]["max_seq_length"],
    max_completion_length = 16384, #config["model_loading_args"]["max_seq_length"],
    num_train_epochs = 1, # Set to 1 for a full training run
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "tensorboard", # Can use Weights & Biases
    output_dir = "grpo_outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


In [12]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        format_reward_func,
        stix_validity_reward_func,
        stix_objects_validity_reward_func,
        accuracy_reward_func
    ],
    args = training_args,
    train_dataset = dataset["train"],
    #eval_dataset = dataset["eval"]
)

In [None]:
from unsloth import unsloth_train
# Start training
trainer.train()
#trainer_stats = unsloth_train(trainer)

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,564 | Num Epochs = 1 | Total steps = 1,564
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 20,971,520/8,000,000,000 (0.26% trained)
