# Finetuning as CausalLM

**This use case is an extension of [Classification_cybersecurity_descriptions](https://github.com/RobustIntelligence/foundation-ai-cookbook/blob/main/2_examples/Classification_cybersecurity_descriptions.ipynb) shown in 2_examples.**

For this demo, we use the CTI-VSP (Cyber Threat Intelligence Vulnerability Severity Prediction) dataset to predict the individual metrics of the Common Vulnerability Scoring System (CVSS) vector. **This dataset was NOT used to train the Foundation-Sec-8B model.**

For example, given a description such as 
```In the Linux kernel through 6.7.1, there is a use-after-free in cec_queue_msg_fh, related to drivers/media/cec/core/cec-adap.c and drivers/media/cec/core/cec-api.c.```, the model is asked to assign the correct label for the category in question. <br>
If the category is Attack Vector, the choices are Network, Adjacent, Local or Physical; if the category is Integrity Impact, the choices are None, Low or High.


For details of the dataset, refer to:
- Paper: https://arxiv.org/pdf/2406.07599 <br>
- GitHub: https://github.com/xashru/cti-bench/tree/main

Note that we have modified the dataset to suit our use case. <br>
We'll finetune Foundation-Sec-8B as well as the original Llama model to show how finetuning works and how Foundation-Sec-8B outperforms the original model.

### Hardware
This finetuning was run on 8x NVIDIA A100 (80GB) GPUs. It is doable with a single GPU, but it will be slower. If you don't have enough GPU memory, consider enabling QLoRA, which saves memory at the cost of a small performance degradation.
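As a rough, back-of-the-envelope illustration of why 4-bit quantization helps, the sketch below estimates the memory needed just to hold the model weights. The parameter count is the "all params" value printed later in this notebook for Llama-3.1-8B (it includes the small LoRA adapter); the bytes-per-parameter values are nominal assumptions that ignore quantization constants, activations, gradients, and optimizer state.

```python
# Rough weight-memory estimate; real usage will be higher because activations,
# optimizer state and quantization metadata are ignored here.
NUM_PARAMS = 8_033_669_120  # "all params" reported for Llama-3.1-8B in this notebook

BYTES_PER_PARAM = {
    "float32": 4.0,
    "float16": 2.0,
    "nf4 (4-bit)": 0.5,  # nominal assumption: 4 bits per weight
}

def weight_memory_gib(num_params: int, bytes_per_param: float) -> float:
    """Return the memory needed to store the weights, in GiB."""
    return num_params * bytes_per_param / 1024**3

estimates = {name: weight_memory_gib(NUM_PARAMS, b) for name, b in BYTES_PER_PARAM.items()}
for name, gib in estimates.items():
    print(f"{name:>12}: {gib:5.1f} GiB")
```

Under these assumptions the 4-bit weights take about a quarter of the float16 footprint, which is where most of the QLoRA memory saving comes from.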

# Setup

In [1]:
# `!export` only sets the variable in a temporary subshell; `%env` sets it for the kernel itself
%env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True

In [2]:
import random
import numpy as np
import torch

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

import warnings
warnings.simplefilter('ignore')

In [3]:
DEVICE = "cuda:0"

print(f"CUDA Available: {torch.cuda.is_available()}")
print(f"GPU: {torch.cuda.get_device_name(0)}")

CUDA Available: True
GPU: NVIDIA A100-SXM4-80GB


# Model Download & Test

In [4]:
import os

HF_TOKEN = os.environ.get("HF_TOKEN")
WB_PROJECT_NAME = "finetuning_demo"

LLAMA_MODEL_ID = "meta-llama/Llama-3.1-8B"
FOUNDATION_SEC_8B_MODEL_ID = "fdtn-ai/Foundation-Sec-8B"

In [5]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# tokenizer is the same for all processes

tokenizer = AutoTokenizer.from_pretrained(LLAMA_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

In [6]:
def load_model(model_id):

    # 4-bit NF4 quantization (QLoRA). To use plain LoRA instead, remove this config
    # and the `quantization_config` argument below.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit = True,
        bnb_4bit_quant_type = "nf4",
        bnb_4bit_compute_dtype = "float16",
        bnb_4bit_use_double_quant = True
    )

    # `device_map` already places the model on DEVICE, so no extra `.to(DEVICE)` is needed
    model = AutoModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path = model_id,
        device_map = DEVICE,
        quantization_config = bnb_config,
    )
    model.config.use_cache = False
    model.config.pretraining_tp = 1
    model.generation_config.top_p = None
    model.generation_config.temperature = None
    model.generation_config.pad_token_id = tokenizer.eos_token_id    

    return model

In [7]:
splitter = "My choice:"
MAX_LENGTH = 1024
MAX_NEW_TOKEN = 3

def inference(prompt, model):
    inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        _output = model.generate(
            **inputs,
            max_new_tokens = MAX_NEW_TOKEN,
            do_sample = False,
            repetition_penalty = 1.2,
        )
    output = tokenizer.decode(_output[0], skip_special_tokens = True)
    response = output.split(splitter)[-1].strip()
    return response
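
The post-processing step in `inference` can be tried without loading a model. The sketch below is a minimal illustration, assuming a made-up decoded string; `extract_answer` is a hypothetical helper that mirrors the last two lines of `inference`.

```python
# Toy illustration of the answer-extraction step; no model is needed.
splitter = "My choice:"

def extract_answer(decoded_text: str) -> str:
    """Return the text that follows the last occurrence of the splitter."""
    return decoded_text.split(splitter)[-1].strip()

# Made-up decoded output: a few-shot prompt followed by the generated answer.
decoded = (
    "description: example A\nMy choice: Low\n\n"
    "description: example B\nMy choice: High"
)
print(extract_answer(decoded))
```

Because the decoded sequence contains the whole prompt, taking the segment after the last `My choice:` skips the few-shot answers and keeps only the newly generated one.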

Let's see how each model behaves on an example. 

We give the same few-shot prompt to each model and compare the outputs. <br>
The correct answer is **High**. The original Llama model answers incorrectly, while Foundation-Sec-8B answers correctly.

In [8]:
prompt = '''
I have a description about a threat intelligence analysis.

description: Cross Site Scripting vulnerability in the input parameter in eyoucms v.1.6.5 allows a remote attacker to run arbitrary code via crafted URL.
Regarding Integrity Impact I will answer only one of the following choices in 1 word: None, Low or High
My choice: Low

description: The EventON WordPress plugin before 2.2 does not sanitise and escape some of its settings, which could allow high privilege users such as admin to perform Stored HTML Injection attacks even when the unfiltered_html capability is disallowed.
Regarding Availability Impact I will answer only one of the following choices in 1 word: None, Low or High
My choice: None

description: A vulnerability, which was classified as critical, was found in Youke365 up to 1.5.3. Affected is an unknown function of the file /app/api/controller/caiji.php of the component Parameter Handler. The manipulation of the argument url leads to server-side request forgery. It is possible to launch the attack remotely. The exploit has been disclosed to the public and may be used. VDB-249870 is the identifier assigned to this vulnerability.
Regarding Attack Vector Impact I will answer only one of the following choices in 1 word: Network, Adjacent, Local or Physical
My choice: Network

description: ASQL injection vulnerability in EmpireCMS v7.5, allows remote attackers to execute arbitrary code and obtain sensitive information via the DoExecSql function.
Regarding Confidentiality Impact I will answer only one of the following choices in 1 word: Low or High
My choice: High

description: IBM WebSphere Application Server Liberty 17.0.0.3 through 24.0.0.4 is vulnerable to a denial of service, caused by sending a specially crafted request. A remote attacker could exploit this vulnerability to cause the server to consume memory resources.  IBM X-Force ID:  280400.
Regarding Privileges Required I will answer only one of the following choices in 1 word: None, Low or High
My choice: None

description: In vsp driver, there is a possible use after free due to a logic error. This could lead to local denial of service with System execution privileges needed
Regarding Privileges Required I will answer only one of the following choices in 1 word: None, Low or High
My choice:'''

In [9]:
llama_model = load_model(LLAMA_MODEL_ID)
llama_output_test = inference(prompt, llama_model)

# To avoid OOM errors, load one model at a time and release it before loading the next
import gc

llama_model = None
gc.collect()
torch.cuda.empty_cache()

print("Llama-3.1-8B's answer:", llama_output_test)

Llama-3.1-8B's answer: None


In [10]:
foundation_sec_8b_model = load_model(FOUNDATION_SEC_8B_MODEL_ID)
foundation_sec_8b_output_test = inference(prompt, foundation_sec_8b_model)

foundation_sec_8b_model = None
gc.collect()
torch.cuda.empty_cache()

print("Foundation-Sec-8B's answer:", foundation_sec_8b_output_test)

Foundation-Sec-8B's answer: High


# Data Preparation

Let's download datasets and pre-process them for evaluation and finetuning.


In [11]:
prompt_template = '''I have a description about a threat intelligence analysis.
description: {description}
Regarding {category} I will answer only one of the following choices in 1 word: {choices}
My choice:'''

finetuning_prompt_template = '''I have a description about a threat intelligence analysis.
description: {description}
Regarding {category} I will answer only one of the following choices in 1 word: {choices}
My choice:{label}'''

In [12]:
# Download data from https://github.com/xashru/cti-bench/blob/main/data/cti-vsp.tsv first
# Here it's assumed that cti-vsp is downloaded at cti-vsp directory
from pathlib import Path

PATH_TO_CTI_VSP = Path("cti-vsp")

Since each record has 8 categories to classify, split each record into 8 rows of category, description, and label.
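
To make the split concrete, here is a minimal sketch on a single made-up record. The vector string and the helper `parse_cvss_vector` are illustrative assumptions; the string follows the `CVSS:3.1/KEY:VALUE/...` layout that the parsing code in this notebook expects.

```python
import re

# Made-up CVSS v3.1 vector string in the layout the notebook's parser expects.
vector = "CVSS:3.1/AV:N/AC:L/PR:N/UI:R/S:C/C:L/I:L/A:N"

def parse_cvss_vector(vector: str) -> dict:
    """Map each metric abbreviation (e.g. "AV") to its value code (e.g. "N")."""
    # Drop the version prefix, then read every "/KEY:VALUE" pair.
    body = vector.replace("CVSS:3.1", "")
    return dict(re.findall(r"/([A-Z]+):([^/]+)", body))

metrics = parse_cvss_vector(vector)
description = "example vulnerability description"  # placeholder text
# One (description, metric, value) row per metric in the vector.
pairs = [(description, key, value) for key, value in metrics.items()]
print(len(pairs))
print(metrics)
```

Each description therefore contributes one classification example per CVSS metric, which is why the dataset grows by a factor of 8 before duplicates are dropped.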

In [13]:
import pandas as pd
import re
from datasets import Dataset, load_from_disk

VULNERABILITY_CATEGORIES_SHORT = [
    "AV",
    "AC",
    "PR",
    "UI",
    "S",
    "C",
    "I",
    "A"
]
VULNERABILITY_CATEGORIES_LONG = [
    "Attack Vector",
    "Attack Complexity",
    "Privileges Required",
    "User Interaction",
    "Scope",
    "Confidentiality Impact",
    "Integrity Impact",
    "Availability Impact"
]
NOT_FOUND = "N/A"


def _extract_label(combined_labels_str):

    combined_labels_str = combined_labels_str.replace("CVSS:3.1", "")

    def _extract_label_from_vector(combined_labels_str, category):
        match = re.search(rf'/{category}:([^/]+)', combined_labels_str)
        if match:
            return _map_to_full_labels(category, match.group(1))
        return NOT_FOUND

    return [_extract_label_from_vector(combined_labels_str, category) for category in VULNERABILITY_CATEGORIES_SHORT]


# - Attack Vector (AV): Network (N), Adjacent (A), Local (L), Physical (P)
# - Attack Complexity (AC): Low (L), High (H)
# - Privileges Required (PR): None (N), Low (L), High (H)
# - User Interaction (UI): None (N), Required (R)
# - Scope (S): Unchanged (U), Changed (C)
# - Confidentiality (C): None (N), Low (L), High (H)
# - Integrity (I): None (N), Low (L), High (H)
# - Availability (A): None (N), Low (L), High (H)
def _map_to_full_labels(category, label):
    mapping = {
        "AV": {
            "N": "Network",
            "A": "Adjacent",
            "L": "Local",
            "P": "Physical",
        },
        "AC": {
            "L": "Low",
            "H": "High",
        },
        "PR": {
            "N": "None",
            "L": "Low",
            "H": "High",
        },
        "UI": {
            "N": "None",
            "R": "Required",
        },
        "S": {
            "U": "Unchanged",
            "C": "Changed",
        },
        "C": {
            "N": "None",
            "L": "Low",
            "H": "High",
        },
        "I": {
            "N": "None",
            "L": "Low",
            "H": "High",
        },
        "A": {
            "N": "None",
            "L": "Low",
            "H": "High",
        },
    }
    return mapping[category][label]


def extract_labels(tsv_path, output_csv_path):
    _df = pd.read_csv(tsv_path, sep='\t')
    rows = []
    descriptions = _df["Description"].to_list()
    combined_labels = _df["GT"].to_list() 
    for desc, cb in zip(descriptions, combined_labels):
        labels = _extract_label(cb)
        assert len(labels) == len(VULNERABILITY_CATEGORIES_LONG)
        for i in range(len(VULNERABILITY_CATEGORIES_LONG)):
            rows.append([VULNERABILITY_CATEGORIES_LONG[i], desc, labels[i]])
    df = pd.DataFrame(rows, columns=["category", "description", "label"])
    df = df.drop_duplicates().set_index("category")
    print("Dataset summary \n\n", df.groupby(["category", "label"]).count())    
    # "category" is the index at this point, so write the index to keep that column in the CSV
    df.to_csv(output_csv_path)
    return df


def get_benchmark_by_majority_classifier(dataset):
    eval_datasets = dataset
    total_len = len(eval_datasets)
    df_eval_datasets = eval_datasets.to_pandas()
    df_group = df_eval_datasets.groupby(["category", "label"]).count()
    df_group = df_group.reset_index()
    df_group = df_group.set_index("category")
    df_group["category_total"] = df_group.groupby(df_group.index).sum()["description"]
    df_group["pct"] = df_group["description"] / df_group["category_total"]
    df_group["weighted_pct"] = df_group["pct"] * df_group["category_total"] / total_len
    accuracy = df_group.groupby(df_group.index).max().sum()["weighted_pct"]
    return accuracy


TSV_PATH = PATH_TO_CTI_VSP / "cti-vsp.tsv"
CSV_PATH = PATH_TO_CTI_VSP / "modified_cti-vsp.csv"
df = extract_labels(TSV_PATH, CSV_PATH)

Dataset summary 

                                   description
category               label                 
Attack Complexity      High                28
                       Low                960
Attack Vector          Adjacent            16
                       Local              187
                       Network            783
                       Physical             2
Availability Impact    High               560
                       Low                  9
                       None               419
Confidentiality Impact High               544
                       Low                278
                       None               166
Integrity Impact       High               482
                       Low                279
                       None               227
Privileges Required    High                83
                       Low                337
                       None               568
Scope                  Changed            252
               

Split the dataset into train and test sets and save them to disk. <br>
Let's also compute the accuracy of a majority-class classifier as a baseline to compare against the finetuned models later.
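
The baseline always predicts the most frequent label of each category. The sketch below shows the same idea on a made-up toy set using only the standard library; `majority_baseline_accuracy` is a hypothetical helper and should agree with `get_benchmark_by_majority_classifier` when given the same (category, label) pairs.

```python
from collections import Counter, defaultdict

# Toy evaluation set of (category, label) pairs; the values are made up.
toy_eval = [
    ("Attack Vector", "Network"),
    ("Attack Vector", "Network"),
    ("Attack Vector", "Local"),
    ("Scope", "Unchanged"),
    ("Scope", "Changed"),
    ("Scope", "Unchanged"),
    ("Scope", "Unchanged"),
]

def majority_baseline_accuracy(pairs):
    """Accuracy of always predicting the most frequent label of each category."""
    counts = defaultdict(Counter)
    for category, label in pairs:
        counts[category][label] += 1
    # Within a category, the majority label is right exactly as often as it occurs.
    correct = sum(max(c.values()) for c in counts.values())
    return correct / len(pairs)

baseline = majority_baseline_accuracy(toy_eval)
print(round(baseline, 4))
```

A model that scores below this number is doing worse than ignoring the description entirely, so it is a useful floor when reading the accuracies reported later.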

In [14]:
HF_DATASET_NAME = "hf_dataset"

SPLIT_SIZE = 0.2
SEED = 42
dataset = Dataset.from_pandas(df).train_test_split(test_size = SPLIT_SIZE, shuffle = True, seed = SEED)
dataset.save_to_disk(PATH_TO_CTI_VSP / HF_DATASET_NAME)
benchmark = get_benchmark_by_majority_classifier(dataset["test"])
print("\nBenchmark by majority classifier:", round(benchmark, 4))


Benchmark by majority classifier: 0.6629


In [15]:
VSP_MAPPING = {
    "Attack Vector": "Network, Adjacent, Local or Physical",
    "Attack Complexity": "Low or High",
    "Privileges Required": "None, Low or High",
    "User Interaction": "None or Required",
    "Scope": "Unchanged or Changed",
    "Confidentiality Impact": "None, Low or High",
    "Integrity Impact": "None, Low or High",
    "Availability Impact": "None, Low or High",
}


def load_dataset_and_preprocess():

    remove_columns = ["category", "description", "label"]

    def _preprocess_data(examples):

        prompts = [
            finetuning_prompt_template.format(
                description = description[:MAX_LENGTH - 100],  # character-level truncation so the tokenized prompt, including the label at the end, stays within max_length
                category = category,
                label = label,
                choices = VSP_MAPPING[category],
            )
            for description, category, label in zip(
                examples["description"],
                examples["category"],
                examples["label"],
            )
        ]
        
        return tokenizer(prompts, truncation = True, padding = "max_length", max_length = MAX_LENGTH)

    hf_datasets = load_from_disk(PATH_TO_CTI_VSP /  HF_DATASET_NAME)
    train_data = hf_datasets["train"]
    test_data = hf_datasets["test"]
    tokenized_train = train_data.map(_preprocess_data, batched = True, remove_columns = remove_columns)
    tokenized_test = test_data.map(_preprocess_data, batched = True, remove_columns = remove_columns)
    print(f"Train samples: {len(tokenized_train)}, Test samples: {len(tokenized_test)}")

    return tokenized_train, tokenized_test

In [16]:
tokenized_train, tokenized_test = load_dataset_and_preprocess()

Train samples: 6323, Test samples: 1581


# Evaluation (before finetuning)

Let's see how the models perform before finetuning.

In [17]:
def get_prompts_and_labels():
    hf_dataset = load_from_disk(PATH_TO_CTI_VSP / HF_DATASET_NAME)
    df = hf_dataset["test"].to_pandas()
    prompts = []
    labels = []
    for _, row in df.iterrows():
        category = row['category']
        choices = VSP_MAPPING[category]
        description = row['description']
        label = row['label']
        prompt = prompt_template.format(description=description, category=category, choices=choices)
        prompts.append(prompt)
        labels.append(label)
    return prompts, labels

In [18]:
def evaluate_pred(prompts, labels, model):
    preds = [inference(prompt, model) for prompt in prompts]
    num_exist = sum(1 for label, pred in zip(labels, preds) if str(label).lower() in str(pred).lower())
    print(f"{num_exist} out of total {len(labels)}")
    return round(num_exist/len(labels), 4)

def eval(model_id):
    model = load_model(model_id)
    prompts, labels = get_prompts_and_labels()
    result = evaluate_pred(prompts, labels, model)
    print(f"Accuracy: {result}")
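
The scoring rule in `evaluate_pred` is a case-insensitive substring match, not an exact match. The sketch below isolates that rule on made-up (label, prediction) pairs; `is_correct` is a hypothetical helper that restates the comparison used above.

```python
def is_correct(label: str, pred: str) -> bool:
    """Case-insensitive substring match, the same comparison as in `evaluate_pred`."""
    return str(label).lower() in str(pred).lower()

# Made-up (label, prediction) pairs.
examples = [
    ("High", "High"),                # exact match
    ("Low", "low\n\ndescription"),   # extra generated text is tolerated
    ("None", "Low"),                 # wrong answer
    ("Low", "Low or High"),          # a hedged answer still counts as correct
]
results = [is_correct(label, pred) for label, pred in examples]
accuracy = sum(results) / len(results)
print(results, accuracy)
```

The last pair shows the trade-off: the lenient match is robust to trailing tokens, but it can over-count when a model lists several choices. A stricter variant could compare only the first generated word.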

In [19]:
eval(LLAMA_MODEL_ID)

674 out of total 1581
Accuracy: 0.4263


In [20]:
eval(FOUNDATION_SEC_8B_MODEL_ID)

890 out of total 1581
Accuracy: 0.5629


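Before any finetuning, Foundation-Sec-8B (0.5629) already outperforms the original Llama model (0.4263), although both are still below the majority-class baseline of 0.6629.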

# Finetuning & Evaluation

Let's finetune the models using the QLoRA approach.

In [21]:
from transformers import TrainingArguments
from peft import LoraConfig, get_peft_model, PeftConfig
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM

OUTPUT_DIR = "./checkpoints"

def train(model_id):

    _output_dir = Path(OUTPUT_DIR) / str(f"{model_id}").replace("/", "_")

    model = load_model(model_id)

    training_args = TrainingArguments(
        output_dir = _output_dir,
        label_names = ["labels"],
        run_name = "finetuning_demo",
        num_train_epochs = 10,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 8,
        eval_strategy = "no",
        logging_steps = 50,
        learning_rate = 4.e-5,
        weight_decay = 0.001,
        fp16 = False,
        bf16 = False,
        max_grad_norm = 0.3,
        max_steps = -1,
        group_by_length = True,
        lr_scheduler_type = "constant",
        seed = SEED,
        report_to = ["none"],
    )
    
    peft_parameters = LoraConfig(
        lora_alpha = 8,
        lora_dropout = 0.1,
        r = 8,
        bias = "none",
        task_type = "CAUSAL_LM",
    )
    peft_model = get_peft_model(model, peft_parameters)
    peft_model.print_trainable_parameters()

    data_collator = DataCollatorForCompletionOnlyLM(response_template = "My choice", tokenizer = tokenizer)

    trainer = SFTTrainer(
        model = peft_model,
        train_dataset = tokenized_train,
        eval_dataset = tokenized_test,
        peft_config = peft_parameters,
        args = training_args,
        data_collator = data_collator,
    )

    trainer.train()
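
The completion-only collator restricts the loss to the tokens that follow the response template. The sketch below is a simplified illustration of that idea on made-up token ids, not TRL's implementation; `mask_prompt_tokens` is a hypothetical helper.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss ignores by default

def mask_prompt_tokens(input_ids, response_template_ids):
    """Return labels where everything up to and including the template is ignored."""
    n = len(response_template_ids)
    labels = list(input_ids)
    # Find the first occurrence of the template and mask everything up to its end.
    for start in range(len(input_ids) - n + 1):
        if input_ids[start:start + n] == response_template_ids:
            end = start + n
            labels[:end] = [IGNORE_INDEX] * end
            return labels
    # Template not found: ignore the whole sequence so it adds nothing to the loss.
    return [IGNORE_INDEX] * len(input_ids)

# Made-up token ids: 7, 8 stand for "My choice"; 9 and 42 for ":" and the label.
input_ids = [1, 5, 6, 7, 8, 9, 42]
labels = mask_prompt_tokens(input_ids, response_template_ids=[7, 8])
print(labels)
```

With this masking the model is only penalized for the answer tokens, so the long vulnerability description does not dominate the training loss.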

In [22]:
from transformers.trainer_utils import get_last_checkpoint
from peft import PeftModel

def load_finetuned_model(original_model_id):
    _dir = Path(OUTPUT_DIR) / str(f"{original_model_id}").replace("/", "_")
    last_checkpoint = get_last_checkpoint(_dir)
    print("last_checkpoint:", last_checkpoint)
    # Load the base model in fp16 and merge the LoRA adapter weights into it for inference
    original_model = AutoModelForCausalLM.from_pretrained(
        original_model_id,
        torch_dtype = torch.float16,
        device_map = DEVICE,
    )
    peft_model = PeftModel.from_pretrained(original_model, last_checkpoint)
    model = peft_model.merge_and_unload()
    model.generation_config.top_p = None
    model.generation_config.temperature = None
    model.generation_config.pad_token_id = tokenizer.eos_token_id
    return model


def eval_finetuning(original_model_id):
    model = load_finetuned_model(original_model_id)
    prompts, labels = get_prompts_and_labels()
    result = evaluate_pred(prompts, labels, model)
    print(f"metrics: {result}")

In [23]:
train(LLAMA_MODEL_ID)

trainable params: 3,407,872 || all params: 8,033,669,120 || trainable%: 0.0424


Step,Training Loss
50,0.2902
100,0.0469
150,0.034
200,0.0332
250,0.035
300,0.0325
350,0.0255
400,0.0259
450,0.025
500,0.0231
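
The trainable-parameter count printed above (3,407,872) can be reproduced by hand. The sketch below assumes the standard Llama-3.1-8B dimensions and that, because `LoraConfig` sets no `target_modules`, PEFT falls back to adapting `q_proj` and `v_proj` in every layer; neither assumption is stated in this notebook.

```python
# Assumed Llama-3.1-8B dimensions and PEFT defaults (not stated in this notebook).
NUM_LAYERS = 32
HIDDEN_SIZE = 4096   # input width of q_proj and v_proj
Q_OUT = 4096         # q_proj output width (32 heads x 128)
V_OUT = 1024         # v_proj output width (8 KV heads x 128)
RANK = 8             # `r` in the LoraConfig used for training

def lora_params(in_features: int, out_features: int, rank: int) -> int:
    """LoRA adds two matrices to a linear layer: A (rank x in) and B (out x rank)."""
    return rank * in_features + out_features * rank

per_layer = lora_params(HIDDEN_SIZE, Q_OUT, RANK) + lora_params(HIDDEN_SIZE, V_OUT, RANK)
total = NUM_LAYERS * per_layer
print(f"{total:,}")
```

Under these assumptions the result matches the reported count, which suggests that only the query and value projections are being adapted. Raising `r` or adding more target modules would scale this number accordingly.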


In [24]:
eval_finetuning(LLAMA_MODEL_ID)

last_checkpoint: checkpoints/meta-llama_Llama-3.1-8B/checkpoint-980


1383 out of total 1581
metrics: 0.8748
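
The checkpoint number 980 reported above is consistent with a simple step count. The sketch below is a rough check, assuming all 8 GPUs take part in data-parallel training and that batches left over at the end of an epoch do not count as an extra optimizer step; both are assumptions, not facts stated in this notebook.

```python
import math

NUM_TRAIN_SAMPLES = 6323   # "Train samples" printed earlier
PER_DEVICE_BATCH = 1       # per_device_train_batch_size
GRAD_ACCUM = 8             # gradient_accumulation_steps
NUM_EPOCHS = 10            # num_train_epochs
NUM_DEVICES = 8            # assumption: all 8 GPUs take part in training

# Batches per epoch once the data is spread across the devices.
batches_per_epoch = math.ceil(NUM_TRAIN_SAMPLES / (PER_DEVICE_BATCH * NUM_DEVICES))
# Assumption: batches that do not fill a full accumulation window
# are not counted as an additional optimizer step.
steps_per_epoch = batches_per_epoch // GRAD_ACCUM
total_steps = steps_per_epoch * NUM_EPOCHS
effective_batch = PER_DEVICE_BATCH * NUM_DEVICES * GRAD_ACCUM
print(effective_batch, steps_per_epoch, total_steps)
```

On a single GPU the same settings would give an effective batch size of 8 and roughly eight times as many optimizer steps, so the learning rate or accumulation steps may need adjusting to reproduce these results.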


In [25]:
train(FOUNDATION_SEC_8B_MODEL_ID)

trainable params: 3,407,872 || all params: 8,034,717,696 || trainable%: 0.0424


Step,Training Loss
50,0.2027
100,0.0421
150,0.031
200,0.0305
250,0.0296
300,0.0261
350,0.0236
400,0.0243
450,0.0226
500,0.0214


In [26]:
eval_finetuning(FOUNDATION_SEC_8B_MODEL_ID)

last_checkpoint: checkpoints/fdtn-ai_Foundation-Sec-8B/checkpoint-980


1446 out of total 1581
metrics: 0.9146


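Finetuning improves both models: Llama-3.1-8B rises from 0.4263 to 0.8748, and Foundation-Sec-8B from 0.5629 to 0.9146. Both now exceed the majority-class baseline of 0.6629, and the finetuned Foundation-Sec-8B still outperforms the finetuned Llama model.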
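
The accuracies quoted in this notebook can be recomputed from the correct-answer counts printed by the evaluation cells. The sketch below is only a convenience for comparing the four runs side by side; the dictionary layout is an illustrative choice.

```python
TOTAL = 1581  # size of the test split

# Correct-answer counts printed by the evaluation cells in this notebook.
correct = {
    ("Llama-3.1-8B", "before"): 674,
    ("Llama-3.1-8B", "after"): 1383,
    ("Foundation-Sec-8B", "before"): 890,
    ("Foundation-Sec-8B", "after"): 1446,
}
accuracy = {key: round(n / TOTAL, 4) for key, n in correct.items()}

for model in ("Llama-3.1-8B", "Foundation-Sec-8B"):
    before, after = accuracy[(model, "before")], accuracy[(model, "after")]
    print(f"{model}: {before:.4f} -> {after:.4f} (+{after - before:.4f})")
```

Keeping the raw counts alongside the rounded accuracies makes it easy to compute other comparisons later, such as the number of additional test items each model gets right after finetuning.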