In [None]:
import os
import re
import json

import warnings
warnings.filterwarnings("ignore")

from multiprocessing import cpu_count
num_proc = cpu_count()

import yaml

from dataprep.stix.StixConfig import StixToPydanticMap, STIX, CustomSTIX
from pydantic import BaseModel, ValidationError


from evaluation.stix_evaluator import STIXEvaluator

from unsloth import FastLanguageModel, is_bfloat16_supported
from unsloth.chat_templates import get_chat_template
import torch

from trl import GRPOConfig, GRPOTrainer

from data_processor import SplittedJsonIoDataset
from customs import customize_tokenizer

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator
# before the model is loaded further down.
torch.cuda.empty_cache()

In [None]:
# Read the GRPO run configuration; safe_load restricts parsing to plain YAML
# types (it is the shorthand for yaml.load with SafeLoader).
with open("grpo_config.yaml", "r") as f:
    config = yaml.safe_load(f)

In [None]:
# model, tokenizer = FastLanguageModel.from_pretrained(
#     model_name = "/mnt/data/training-outputs/Llama/Llama-3.1-8B-Instruct-Not-Quantized/checkpoint-190",
#     fast_inference = True,
#     load_in_4bit = False,
#     max_seq_length = None,
#     gpu_memory_utilization = 0.7
# )
# model = model.merge_and_unload()
# model.save_pretrained("grpo_model_input")
# tokenizer.save_pretrained("grpo_model_input")

In [None]:
# # Check if model is a peft model
# import peft.helpers

# def load_model_and_tokenizer(model_name_or_path, config, **kwargs):
#     if peft.helpers.check_if_peft_model(model_name_or_path):
#         model, tokenizer = FastLanguageModel.from_pretrained(
#                 model_name=model_name_or_path,
#                 **kwargs
#             )
#         if config["merge_peft_model"]:
#             model.merge_and_unload()
#     else:
#         model, tokenizer = FastLanguageModel.from_pretrained(
#             **config["model_loading_args"]
#         )
#     return model, tokenizer

# Load the model and tokenizer from the local "grpo_model_input" directory
# (presumably the merged checkpoint written by the commented-out
# save_pretrained cell earlier in this notebook — TODO confirm).
# NOTE(review): max_seq_length is None here, while the GRPOConfig cell sets
# max_prompt_length and max_completion_length to 8000 each — confirm the
# model's effective context length covers both.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "grpo_model_input",
    fast_inference = True,
    load_in_4bit = True,
    max_seq_length = None,
    gpu_memory_utilization = 0.7
)

In [None]:
# Display the sequence length the loaded model reports
# (max_seq_length = None was passed at load time).
model.max_seq_length

In [None]:
#model = model.merge_and_unload()
#model

In [None]:
# Attach LoRA adapters to the attention and MLP projection layers listed in
# target_modules.
# NOTE(review): this cell rebinds `model`; re-running it without reloading the
# model presumably wraps the already-wrapped model again — confirm before
# re-executing out of order.
model = FastLanguageModel.get_peft_model(
    model,
    r = 32, # LoRA rank; any number > 0 (suggested 8, 16, 32, 64, 128)
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 32,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

In [None]:
#model, tokenizer = customize_tokenizer(model, tokenizer, config)
# no need for deepseek
# config["chat_template"] = "deepseek"
# tokenizer = get_chat_template(tokenizer, config["chat_template"])

In [None]:
# Create the dataset for GRPO training from the tokenizer and the YAML config.
# The result is indexed by split name ("train" / "eval") in the next cell.
dataset = SplittedJsonIoDataset(tokenizer, config).grpo_create()

In [None]:
from datasets import concatenate_datasets

# Fold the eval split into the training data to increase the number of
# learning examples. NOTE: after this, no split of `dataset` is held out.
train_dataset = concatenate_datasets(
    [dataset[split_name] for split_name in ("train", "eval")]
)

In [None]:
def deserialize_answer(answer: str) -> dict:
    """Parse a reference answer given as a JSON string.

    Args:
        answer: JSON text of the reference answer.

    Returns:
        The decoded JSON value.

    Raises:
        json.JSONDecodeError: If ``answer`` is not valid JSON.
    """
    decoded = json.loads(answer)
    return decoded

def deserialize_response_for_evaluation(answer: str) -> dict:
    """Decode a model answer for scoring, falling back to an empty bundle.

    Args:
        answer: Text extracted from the model completion.

    Returns:
        The decoded JSON when ``is_stix_bundle(answer)`` is true; otherwise a
        placeholder bundle with an empty ``objects`` list.
    """
    # Guard clause: anything that is not a valid STIX bundle is scored as an
    # empty bundle instead of raising.
    if not is_stix_bundle(answer):
        return {"id":"", "type":"bundle", "objects":[]}
    return json.loads(answer)

def extract_xml_answer(response: str) -> str:
    """Return the text between the last ``<answer>`` tag and the next ``</answer>``.

    If the opening tag is absent, the whole response is used; if the closing
    tag is absent, everything after the opening tag is used. Surrounding
    whitespace is stripped.

    Args:
        response: Raw completion text.

    Returns:
        The stripped answer payload.
    """
    # rpartition keeps what follows the LAST opening tag (or the whole string
    # when the tag is missing), matching split(...)[-1].
    after_open = response.rpartition("<answer>")[2]
    # partition keeps what precedes the FIRST closing tag, matching split(...)[0].
    payload = after_open.partition("</answer>")[0]
    return payload.strip()

def is_stix_bundle(text: str) -> bool:
    """Report whether ``text`` is JSON that the ``STIX`` model accepts.

    Args:
        text: Candidate bundle as a JSON string.

    Returns:
        True if ``json.loads(text)`` succeeds and ``STIX(**bundle)`` can be
        constructed without raising; False otherwise.
    """
    try:
        STIX(**json.loads(text))
    except Exception:
        # Malformed JSON, non-mapping payloads and validation failures all mean
        # "not a bundle". Catch Exception rather than using a bare except so
        # that KeyboardInterrupt / SystemExit still stop a training run.
        return False
    return True
    
def is_custom_stix_bundle(text: str) -> bool:
    """Report whether ``text`` is JSON that the ``CustomSTIX`` model accepts.

    Args:
        text: Candidate bundle as a JSON string.

    Returns:
        True if ``json.loads(text)`` succeeds and ``CustomSTIX(**bundle)`` can
        be constructed without raising; False otherwise.
    """
    try:
        CustomSTIX(**json.loads(text))
    except Exception:
        # Malformed JSON, non-mapping payloads and validation failures all mean
        # "not a custom bundle". Catch Exception rather than using a bare
        # except so that KeyboardInterrupt / SystemExit still propagate.
        return False
    return True

def count_valid_stix_objects(text: str) -> float:
    """Return the fraction of a bundle's objects that validate individually.

    Args:
        text: Candidate bundle as a JSON string.

    Returns:
        A float in ``[0.0, 1.0]``: the share of entries in ``bundle["objects"]``
        for which ``StixToPydanticMap()(obj)`` does not raise. Returns ``0.0``
        when ``text`` is not a valid STIX bundle or the bundle has no objects.
    """
    if not is_stix_bundle(text):
        return 0.0

    # An empty or missing objects list earns no reward and must not divide by zero.
    objects = json.loads(text).get("objects") or []
    if not objects:
        return 0.0

    smap = StixToPydanticMap()
    valid = 0
    for obj in objects:
        try:
            smap(obj)
        except Exception:
            # Any failure marks this object invalid; keep scoring the rest.
            # Exception (not a bare except) lets KeyboardInterrupt propagate.
            continue
        valid += 1
    return valid / len(objects)

In [None]:
def format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the ``<think>…</think><answer>…</answer>`` layout.

    Args:
        completions: Sequence of completions; the text is read from
            ``completion[0]["content"]`` of each entry.
        **kwargs: Ignored; accepted so the trainer can pass extra columns.

    Returns:
        One score per completion: ``0.5`` when the text starts with a
        ``<think>`` block followed (after optional whitespace) by an
        ``<answer>`` block, otherwise ``0.0``.
    """
    # DOTALL lets "." span newlines. Without it, any multi-line reasoning or
    # pretty-printed JSON answer could never match and always scored 0.0.
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [
        0.5 if pattern.match(completion[0]["content"]) else 0.0
        for completion in completions
    ]

def stix_validity_reward_func(completions, answers, **kwargs) -> list[float]:
    """Reward completions whose extracted answer is a valid STIX bundle.

    Args:
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answers: Reference answers (not used by this reward).
        **kwargs: Ignored.

    Returns:
        ``0.5`` per completion whose ``<answer>`` payload passes
        ``is_stix_bundle``, otherwise ``0.0``.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if is_stix_bundle(candidate) else 0.0)
    return rewards

def custom_stix_validity_reward_func(completions, answers, **kwargs) -> list[float]:
    """Reward completions whose extracted answer is a valid custom STIX bundle.

    Args:
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answers: Reference answers (not used by this reward).
        **kwargs: Ignored.

    Returns:
        ``0.5`` per completion whose ``<answer>`` payload passes
        ``is_custom_stix_bundle``, otherwise ``0.0``.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 if is_custom_stix_bundle(candidate) else 0.0)
    return rewards

def stix_objects_validity_reward_func(completions, answers, **kwargs) -> list[float]:
    """Reward completions in proportion to their individually valid STIX objects.

    Args:
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answers: Reference answers (not used by this reward).
        **kwargs: Ignored.

    Returns:
        Per completion, ``0.5`` multiplied by the value returned by
        ``count_valid_stix_objects`` for its ``<answer>`` payload.
    """
    rewards = []
    for completion in completions:
        candidate = extract_xml_answer(completion[0]['content'])
        rewards.append(0.5 * count_valid_stix_objects(candidate))
    return rewards

def accuracy_reward_func(completions, answers, **kwargs) -> list[float]:
    """Score each completion against its reference answer with ``STIXEvaluator``.

    Args:
        completions: Sequence of completions; text is read from
            ``completion[0]['content']``.
        answers: Reference answers as JSON strings, paired positionally with
            ``completions``.
        **kwargs: Ignored.

    Returns:
        Per pair, element ``[2]`` of ``evaluator.evaluate_single(response, answer)``
        (presumably the aggregate score — TODO confirm against STIXEvaluator).
    """
    evaluator = STIXEvaluator()

    # Invalid or unparsable responses become an empty bundle, so they are
    # still scored instead of raising.
    predicted_bundles = [
        deserialize_response_for_evaluation(extract_xml_answer(completion[0]['content']))
        for completion in completions
    ]
    reference_bundles = [deserialize_answer(reference) for reference in answers]

    rewards = []
    for predicted, reference in zip(predicted_bundles, reference_bundles):
        rewards.append(evaluator.evaluate_single(predicted, reference)[2])
    return rewards

In [None]:
# GRPO training hyperparameters.
# NOTE(review): max_prompt_length + max_completion_length = 16000 tokens, while
# the model was loaded with max_seq_length = None — confirm the model's
# effective context length covers both.
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # accumulate gradients over 4 steps per optimizer update
    num_generations = 2, # completions sampled per prompt; more generations need more memory
    max_prompt_length = 8000, #config["model_loading_args"]["max_seq_length"],
    max_completion_length = 8000, #config["model_loading_args"]["max_seq_length"],
    num_train_epochs = 2, # NOTE(review): an earlier comment said "set to 1 for a full training run" — confirm 2 is intended
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "tensorboard", # Can use Weights & Biases
    output_dir = "grpo_outputs",
    # Sampling parameters used when generating completions
    temperature=0.7,
    top_p=0.6,
    repetition_penalty=1.1
)

In [None]:
# Build the GRPO trainer with the five reward functions defined above.
# The format and validity rewards each contribute at most 0.5 per completion;
# the range of accuracy_reward_func depends on STIXEvaluator (not visible here).
# Training uses train_dataset, i.e. the train and eval splits combined.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        format_reward_func,
        stix_validity_reward_func,
        custom_stix_validity_reward_func,
        stix_objects_validity_reward_func,
        accuracy_reward_func
    ],
    args = training_args,
    train_dataset = train_dataset
)

In [None]:
#!PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [None]:
# unsloth_train is only referenced by the commented-out alternative below.
from unsloth import unsloth_train
# Start training with the trainer's own loop.
trainer.train()
# Alternative entry point, currently disabled:
#trainer_stats = unsloth_train(trainer)