In [None]:
import warnings
warnings.filterwarnings("ignore")

from pprint import pprint

import json

import os

from unsloth import FastLanguageModel
import torch

from multiprocessing import cpu_count
num_proc = cpu_count()

import yaml

from data_processor import SplittedJsonIoDataset
from customs import customize_tokenizer

from unsloth import UnslothTrainer, UnslothTrainingArguments

from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import TrainingArguments, DataCollatorForSeq2Seq, DataCollatorForLanguageModeling
from unsloth import is_bfloat16_supported

from unsloth.chat_templates import train_on_responses_only

from unsloth import unsloth_train

from utils import save_log_history

In [None]:
def load_json(path:str, filename:str):
    """Parse the UTF-8 encoded JSON file ``filename`` found in directory ``path``.

    Args:
        path: Directory that contains the file.
        filename: Name of the JSON file inside ``path``.

    Returns:
        The decoded JSON document (whatever ``json`` produces for its content).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    full_path = os.path.join(path, filename)
    with open(full_path, mode="r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
    
def format_example(example:dict, system_message):
    """Build the three-turn chat conversation for one training example.

    Args:
        example: Mapping with an ``"input"`` entry (used verbatim as the user
            turn) and an ``"output"`` entry (JSON-serialised into the answer).
        system_message: Instruction text placed in the first turn.

    Returns:
        A list of three ``{"role", "content"}`` dicts: instruction, user
        input, and the expected answer as a JSON string.
    """
    # NOTE(review): the instruction turn uses role "assistant", not "system".
    # format_input_prompt does the same at inference time, so the two stay
    # consistent; confirm this is intended before changing either one.
    instruction_turn = {"role": "assistant", "content": system_message}
    report_turn = {"role": "user", "content": example["input"]}
    answer_turn = {"role": "assistant", "content": json.dumps(example["output"])}
    return [instruction_turn, report_turn, answer_turn]

In [None]:
# Instruction prompt for every CTI object type handled by this notebook.
# Keys are the type prefixes taken from the example file names (the part
# before "--"); values are used verbatim as the first chat turn, both when
# building the training conversations and at inference time.
# NOTE: these strings are part of the training data. Only wording that changed
# the meaning of the instruction was corrected here (the id rule for "malware"
# and the "longitude" field for "location"); all other text, including the
# leading indentation inside the strings, is intentionally left untouched.
system_messages = {
    "domain-name":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all domain names referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe domain names.
                    To describe a domain name you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no domain names are identified return a json with an empty list "objects".
                    Identify all domain names in the folowing CTI report: """,

    "hostname":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all hostnames referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe hostnames.
                    To describe a hostname you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no hostnames are identified return a json with an empty list "objects".
                    Identify all hostnames in the folowing CTI report: """,

    "url":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all URLs referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe URLs.
                    To describe a URL you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no URLs are identified return a json with an empty list "objects".
                    Identify all URLs in the folowing CTI report: """,

    "email-addr":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all email addresses referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe email addresses.
                    To describe an email address you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no email addresses are identified return a json with an empty list "objects".
                    Identify all email addresses in the folowing CTI report: """,

    "ipv4-addr":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all ipv4-addresses referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe ipv4-addresses.
                    To describe an ipv4-address you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no ipv4-addresses are identified return a json with an empty list "objects".
                    Identify all ipv4-addresses in the folowing CTI report: """,

    "cryptocurrency-wallet":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all cryptocurrency-wallets referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe cryptocurrency-wallets.
                    To describe a cryptocurrency-wallet you should provide the fields id, type and value.
                    Instead of using UUID in the id field, use the rule type--value for generating ids.
                    If no cryptocurrency-wallet are identified return a json with an empty list "objects".
                    Identify all cryptocurrency-wallet in the folowing CTI report: """,

    "indicator":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all indicators referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe indicators.
                    To describe an indicator you should provide the fields id, type, name, description: Optional, indicator_types: Optional[list], pattern: str, pattern_type: Literal["stix", "snort", "yara"].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no indicators are identified return a json with an empty list "objects".
                    Identify all indicators in the folowing CTI report: """,

    "file":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all files referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe files.
                    To describe a file you should provide the fields id, type, name[Optional], hashes: dict, size[Optional], mime_type[Optional].
                    Instead of using UUID in the id field, use the rule type--hashes for generating ids.
                    If no files are identified return a json with an empty list "objects".
                    Identify all files in the folowing CTI report: """,

    "attack-pattern":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all attack-patterns referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe attack-patterns.
                    To describe an attack-pattern you should provide the fields id, type, name, description[Optional], aliases[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no attack-patterns are identified return a json with an empty list "objects".
                    Identify all attack-patterns in the folowing CTI report: """,

    "identity":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all identities referenced in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe identities.
                    To describe an identity you should provide the fields id, type, name, description[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no identities are identified return a json with an empty list "objects".
                    Identify all identities in the folowing CTI report: """,

    # Fixed: the id rule previously read "type--nasme"; every other
    # name-keyed prompt uses "type--name".
    "malware":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all malwares referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe malwares.
                    To describe a malware you should provide the fields id, type, name, description[Optional], malware_types[Optional], is_family[Optional], aliases[Optional], os_execution_envs[Optional], architecture_execution_envs[Optional], implementation_languages[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no malwares are identified return a json with an empty list "objects".
                    Identify all malwares in the folowing CTI report: """,

    # Fixed: the field list previously named "longtitude" instead of
    # "longitude". NOTE(review): confirm the example outputs use "longitude".
    "location":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all locations referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe locations.
                    To describe a location you should provide the fields id, type, name, country, description[Optional], latitude[Optional], longitude[Optional], city[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no locations are identified return a json with an empty list "objects".
                    Identify all locations in the folowing CTI report: """,

    "vulnerability":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all vulnerabilities referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe vulnerabilities.
                    To describe a vulnerability you should provide the fields id, type, name, description[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no vulnerabilities are identified return a json with an empty list "objects".
                    Identify all vulnerabilities in the folowing CTI report: """,

    "intrusion-set":"""You are an AI Security Analyst in Cyberthreat Intelligence (CTI). 
                    Your task is to identify all intrusion-sets referenced or implied in a CTI report. 
                    You MUST return a json with a field "objects" being a list of json objects 
                    that describe intrusion-sets.
                    To describe an intrusion-set you should provide the fields id, type, name, description[Optional], aliases[Optional], goals[Optional], resource_level[Optional], primary_motivation[Optional], secondary_motivation[Optional].
                    Instead of using UUID in the id field, use the rule type--name for generating ids.
                    If no intrusion-sets are identified return a json with an empty list "objects".
                    Identify all intrusion-sets in the folowing CTI report: """,

}

In [None]:
# Read the experiment settings from config.yaml; later cells index into the
# resulting dict (e.g. "lora_parameters", "training_arguments").
with open("config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

# Alternative loading path kept for reference: build the model from
# config["model_loading_args"] and then customise the tokenizer.
# model, tokenizer = FastLanguageModel.from_pretrained(
#     **config["model_loading_args"]
# )
# model, tokenizer = customize_tokenizer(model, tokenizer, config)

# Load the model and tokenizer from a previously saved training checkpoint.
# NOTE(review): the checkpoint location is a hard-coded absolute path, and
# max_seq_length is None here while the trainer cell reads
# config["model_loading_args"]["max_seq_length"] — confirm both are intended.
checkpoint_load_kwargs = dict(
    model_name = "/home/deleftheriou/cti-model-training/Llama-3.1-8B-Instruct-CTI-Subtasks/checkpoint-291",
    fast_inference = False,
    load_in_4bit = True,
    max_seq_length = None,
    gpu_memory_utilization = 0.8,
)
model, tokenizer = FastLanguageModel.from_pretrained(**checkpoint_load_kwargs)

In [None]:
# Weights that must receive gradients in addition to whatever is already
# trainable: the saved copies of the output head and the token embeddings.
_EXTRA_TRAINABLE_WEIGHTS = {
    "base_model.model.lm_head.modules_to_save.default.weight",
    "base_model.model.model.embed_tokens.modules_to_save.default.weight",
}

for name, param in model.named_parameters():
    if name in _EXTRA_TRAINABLE_WEIGHTS:
        param.requires_grad = True

# Only parameters with requires_grad=True are summed, so the figure reported
# is the number of *trainable* parameters, not the full model size. The
# message was corrected to say so.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Total number of trainable parameters: {total_params}')

In [None]:
# Directories scanned by the loading cell below; each file name starts with
# the CTI object type followed by "--".
# NOTE(review): absolute, machine-specific paths — consider moving to config.
train_path, validation_path = (
    "/mnt/data/openCTI/splitted-io-pairs/train",
    "/mnt/data/openCTI/splitted-io-pairs/validation",
)

In [None]:
def _load_formatted_examples(directory, skipped_types=("relationship", "report")):
    """Load every example file in ``directory`` as a chat conversation.

    The CTI object type is the part of the file name before the first "--".
    Files whose type is in ``skipped_types`` are ignored; for all others the
    matching prompt from ``system_messages`` is combined with the example via
    ``format_example``.

    Args:
        directory: Folder containing one JSON example per file.
        skipped_types: Type prefixes that have no prompt and are not trained on.

    Returns:
        List of conversations (each a list of role/content dicts), ordered by
        file name.

    Raises:
        KeyError: If a file's type prefix has no entry in ``system_messages``.
    """
    conversations = []
    # os.listdir returns entries in arbitrary order. Sorting makes the dataset
    # order (and therefore fixed indices and seeded shuffles) reproducible
    # across runs and machines.
    for filename in sorted(os.listdir(directory)):
        cti_type = filename.split("--")[0]
        if cti_type in skipped_types:
            continue
        example = load_json(directory, filename)
        conversations.append(format_example(example, system_messages[cti_type]))
    return conversations


# Same procedure for both splits; previously two copy-pasted loops.
formatted_train_list = _load_formatted_examples(train_path)
formatted_eval_list = _load_formatted_examples(validation_path)

In [None]:
# formatted_train_list = formatted_train_list[:5]
# formatted_eval_list = formatted_eval_list[:2]

In [None]:
import datasets


def _render_conversations(conversations):
    """Render each conversation to a single string with the tokenizer's chat template."""
    return [
        tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False)
        for convo in conversations
    ]


def _within_token_limit(row):
    """True when the tokenized text is no longer than config["filter_threshold"]."""
    return len(tokenizer.encode(row["text"])) <= config["filter_threshold"]


# Apply the model's chat template to every example (plain strings, no tokenization yet).
templated_train_list = _render_conversations(formatted_train_list)
templated_eval_list = _render_conversations(formatted_eval_list)

# One Hugging Face dataset per split, each with a single "text" column.
hf_train = datasets.Dataset.from_list([{"text": rendered} for rendered in templated_train_list])
hf_eval = datasets.Dataset.from_list([{"text": rendered} for rendered in templated_eval_list])

# Bundle both splits so they can be filtered together.
dataset = datasets.DatasetDict({"train": hf_train, "eval": hf_eval})

# Optionally drop examples that are too long. When no threshold is configured,
# the tokenizer's model_max_length is written back into the config and used.
if config["filter_dataset"]:
    if not config["filter_threshold"]:
        config["filter_threshold"] = tokenizer.model_max_length
    dataset = dataset.filter(_within_token_limit)

In [None]:
# Show the LoRA settings read from config.yaml; the same dict is unpacked into
# FastLanguageModel.get_peft_model in the next cell.
pprint(config["lora_parameters"])

In [None]:
# Attach LoRA adapters to the loaded model. All adapter settings come from
# config["lora_parameters"]; the result replaces `model` for the rest of the
# notebook.
# NOTE(review): `model` was loaded from a training checkpoint above — confirm
# that calling get_peft_model on it again is intended.
model = FastLanguageModel.get_peft_model(
    model=model,
    **config["lora_parameters"]
)

In [None]:
# Switch read by a later cell: when True the trainer is wrapped with
# train_on_responses_only; False leaves the trainer as constructed.
_train_on_responses_only_bool = False
# Collator passed to the trainer as `data_collator`.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
# Per-run overrides on top of config.yaml. They have to be written into
# config["training_arguments"], because only that sub-dict is unpacked into
# UnslothTrainingArguments when the trainer is created.
config["training_arguments"]["output_dir"] = "Llama-3.1-8B-Instruct-CTI-Subtasks"
config["training_arguments"]["seed"] = 4321
# Fixed: this was assigned to the top-level config["lr_scheduler_type"], a key
# that nothing reads, so the constant schedule was never applied.
config["training_arguments"]["lr_scheduler_type"] = "constant"

In [None]:
# Pick the mixed-precision mode once: bf16 when supported, otherwise fp16.
_bf16_available = is_bfloat16_supported()

# Training arguments: precision flags plus everything under
# config["training_arguments"].
training_args = UnslothTrainingArguments(
    fp16 = not _bf16_available,
    bf16 = _bf16_available,
    **config["training_arguments"]
)

# Create the trainer on the templated "text" column of both splits.
trainer = UnslothTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset["train"],
    eval_dataset = dataset["eval"],
    data_collator = data_collator,
    dataset_text_field = "text",
    # Used only when packing=True for creating a ConstantLengthDataset.
    max_seq_length = config["model_loading_args"]["max_seq_length"],
    packing = config["sft_trainer_arguments"]["apply_packing"],
    dataset_num_proc = num_proc,
    args = training_args,
)

In [None]:
# Optionally wrap the trainer with train_on_responses_only, using the
# instruction/response marker strings from the config. Skipped while
# _train_on_responses_only_bool is False.
if _train_on_responses_only_bool:
    trainer = train_on_responses_only(
        trainer,
        instruction_part = config["instruction_part"],
        response_part = config["response_part"]
    )

In [None]:
# Early stopping is switched off for this run: the value from config.yaml is
# overwritten with False, so the branch below never executes. Remove the
# override (or set a positive patience) to register the callback.
config["early_stopping_patience"] = False

if config["early_stopping_patience"]:
    from transformers import EarlyStoppingCallback
    # The configured value is passed through as the callback's patience.
    early_stopping_callback = EarlyStoppingCallback(early_stopping_patience = config["early_stopping_patience"])
    trainer.add_callback(early_stopping_callback)

In [None]:
# Start training via Unsloth's training entry point and keep the returned
# statistics.
# NOTE(review): resume_from_checkpoint=True presumably continues from the
# latest checkpoint in the configured output_dir and fails if none exists —
# confirm before running on a fresh output directory.
trainer_stats = unsloth_train(trainer, resume_from_checkpoint = True)

: 

In [None]:
# Hand the trainer to the project helper imported from utils.
# NOTE(review): presumably persists the trainer's logged metrics; the output
# location is decided inside the helper — verify against utils.py.
save_log_history(trainer)

In [None]:
from transformers import TextStreamer

# Call Unsloth's for_inference on the model before any generation below.
FastLanguageModel.for_inference(model)
# Streamer passed to model.generate by inference(); configured to skip the
# prompt and special tokens in what it emits.
text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

def format_input_prompt(system_message, user_input):
    """Build the two-turn conversation used to prompt the model at inference.

    Args:
        system_message: Instruction text for the task.
        user_input: The report text to analyse.

    Returns:
        A list with the instruction turn followed by the user turn, each a
        ``{"role", "content"}`` dict.
    """
    # NOTE(review): the instruction turn uses role "assistant", mirroring how
    # format_example builds the training conversations — keep both in sync.
    instruction_turn = {"role": "assistant", "content": system_message}
    report_turn = {"role": "user", "content": user_input}
    return [instruction_turn, report_turn]

def format_validation_example_for_inference(example):
    """Extract the user turn's text from a chat-templated example string.

    Takes the text that follows the first user header and cuts it at the
    start of the assistant turn that comes after it. Surrounding whitespace
    is kept as-is.

    Args:
        example: One templated example, as stored in the dataset's "text" column.

    Returns:
        The substring between the user header and the following assistant turn.

    Raises:
        IndexError: If ``example`` contains no user header.
    """
    user_header = "<|start_header_id|>user<|end_header_id|>"
    assistant_turn_start = "<|eot_id|><|start_header_id|>assistant<|end_header_id|>"
    after_user_header = example.split(user_header)[1]
    user_text, _, _ = after_user_header.partition(assistant_turn_start)
    return user_text

def inference(model, system_message, user_input, max_new_tokens=None, **kwargs):
    """Generate a response for one report and stream it as it is produced.

    The prompt is built with ``format_input_prompt``, rendered and tokenized
    with the module-level ``tokenizer``, and moved to the "cuda" device.
    Output goes through the module-level ``text_streamer``; nothing is returned.

    Args:
        model: Model whose ``generate`` method is called.
        system_message: Instruction text for the task.
        user_input: The report text to analyse.
        max_new_tokens: Generation limit. When falsy, the limit is the model's
            ``max_position_embeddings`` minus the prompt length.
        **kwargs: Forwarded unchanged to ``model.generate``.
    """
    conversation = format_input_prompt(system_message, user_input)
    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors = "pt").to("cuda")
    prompt_length = input_ids.shape[-1]
    if max_new_tokens:
        token_budget = max_new_tokens
    else:
        # No explicit limit: allow whatever room is left after the prompt.
        token_budget = model.config.max_position_embeddings - prompt_length
    model.generate(input_ids, streamer = text_streamer, max_new_tokens=token_budget, **kwargs)

In [None]:
# Try the model on one validation example with the domain-name prompt.
system_message = system_messages["domain-name"]
# Recover the raw user text from the templated eval example.
# NOTE(review): 1364 is a hard-coded position in the eval split; which example
# it refers to depends on dataset order and filtering — confirm it is still a
# domain-name example.
user_input = format_validation_example_for_inference(dataset["eval"]["text"][1364])
# Sampling settings are forwarded to model.generate through **kwargs;
# max_new_tokens=None lets inference() derive the limit from the model config.
inference(model,
          system_message, 
          user_input, 
          max_new_tokens=None,
          temperature=0.7,
          top_p=0.6,
          repetition_penalty=1.1,
          no_repeat_ngram_size=3,
          do_sample=True)

In [None]:
# Print the full templated eval example used above (prompt and expected
# answer) for comparison with the generated output.
print(dataset["eval"]["text"][1364])

In [None]:
# Spot check: does this address occur in the templated text of eval example 134?
# NOTE(review): the cells above inspect index 1364, this one 134 — confirm the
# index is intended and not a typo.
"info@olymp.is" in dataset["eval"]["text"][134]