In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    set_seed,
    DataCollatorForLanguageModeling,
)
import evaluate
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
import torch
import pandas as pd
import numpy as np
from dotenv import load_dotenv
from huggingface_hub import login
import os

pd.set_option('display.max_colwidth', None)

In [2]:
load_dotenv()

RANDOM_SEED = int(os.getenv("RANDOM_SEED"))
HUGGINGFACE_TOKEN = os.getenv("HUGGINGFACE_TOKEN")

set_seed(RANDOM_SEED)
login(token=HUGGINGFACE_TOKEN)

In [3]:
DEVICE = torch.cuda.get_device_name(torch.cuda.current_device())

print(f"""Device Used = {DEVICE}""")

Device Used = NVIDIA GeForce RTX 4060


In [4]:
compute_dtype = getattr(torch, "float16")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=False,
)

In [5]:
MODEL_NAME = os.getenv("HUGGINGFACE_MODEL_NAME")

print(f"""Model Used = {MODEL_NAME}""")

Model Used = google/gemma-2-2b-it


In [6]:
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="left",
    add_eos_token=True,
    add_bos_token=True,
    use_fast=False,
)
tokenizer.pad_token = tokenizer.eos_token

In [7]:
MAX_TOKENS = int(os.getenv("MAX_TOKENS"))
TEMPERATURE = float(os.getenv("TEMPERATURE"))
TOP_P = float(os.getenv("TOP_P"))
REPETITION_PENALTY = float(os.getenv("REPETITION_PENALTY"))

print(f"""Maximum Tokens = {MAX_TOKENS}
Temperature = {TEMPERATURE}
Top_P = {TOP_P}
Repetition Penalty = {REPETITION_PENALTY}""")

Maximum Tokens = 512
Temperature = 0.2
Top_P = 0.8
Repetition Penalty = 1.3


In [8]:
def get_chat_component(role, text):
    return {"role": role, "content": text}

def get_user_chat_component(text):
    return get_chat_component("user", text)

def get_assistant_chat_component(text):
    return get_chat_component("assistant", text)

def generate(model, prompt, 
             max_new_tokens=MAX_TOKENS, 
             temperature=TEMPERATURE, 
             top_p=TOP_P, 
             tokenizer=tokenizer, 
             repetition_penalty=REPETITION_PENALTY):
    
    """
    Generate text using the provided model and prompt
    
    Args:
        model: The language model
        prompt: Input text prompt
        max_new_tokens: Maximum length of generated text
        tokenizer: Pre-instantiated tokenizer (uses global tokenizer by default)
    
    Returns:
        List containing generated text
    """
    seperator = 200 * '-'
    
    # Tokenize the input
    formatted_prompt = tokenizer.apply_chat_template(
        [get_user_chat_component(prompt)],
        tokenize=False, 
        add_generation_prompt=False,
    )
    print(f'{seperator}\nINPUT [{len(formatted_prompt)}]:\n{formatted_prompt}\n{seperator}')

    inputs = tokenizer(
        formatted_prompt,
        return_tensors="pt", 
        truncation=True, 
        padding=True, 
    )
    
    # Move to same device as model
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Generate text
    with torch.no_grad():
        outputs = model.generate(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            num_return_sequences=1,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=False,
            repetition_penalty=repetition_penalty,
        )
    
    # Decode the generated text
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
    
    print(f'{seperator}\nOUTPUT [{len(generated_text)}]:\n{generated_text}\n{seperator}')

    return generated_text

## Data Preparation

In [9]:
GENERATED_DATA_DIR = os.getenv('GENERATED_DATA_DIR')
TRAINING_PROPORTION = float(os.getenv("TRAINING_PROPORTION"))
VALIDATION_PROPORTION = float(os.getenv("VALIDATION_PROPORTION"))
TESTING_PROPORTION = float(os.getenv("TESTING_PROPORTION"))

print(f"""Generated Data Directory = {GENERATED_DATA_DIR}
      
Training Proportion = {TRAINING_PROPORTION}
Validation Proportion = {VALIDATION_PROPORTION}
Testing Proportion = {TESTING_PROPORTION}""")

Generated Data Directory = data_generation/generated_data

Training Proportion = 0.7
Validation Proportion = 0.2
Testing Proportion = 0.1


In [10]:
def format_text(entry):
    entry['text'] = f'Instruct: {entry["input"]}\nOutput: {entry["output"]}'

def compile_dataset(source_dir, 
                    training_proportion=TRAINING_PROPORTION, 
                    validation_proportion=VALIDATION_PROPORTION):
    
    training_df = pd.DataFrame()
    validation_df = pd.DataFrame()
    testing_df = pd.DataFrame()

    for source_csv_dir in os.listdir(source_dir):
        source_csv_paths = os.listdir(os.path.join(source_dir, source_csv_dir))
        for csv_path in source_csv_paths:
            csv = pd.read_csv(os.path.join(source_dir, source_csv_dir, csv_path))
            csv = csv.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)
            csv['text'] = csv.apply(lambda entry: tokenizer.apply_chat_template([
                get_user_chat_component(entry["input"]),
                get_assistant_chat_component(entry["output"]),
            ], tokenize=False, add_generation_prompt=False), 1)
            total_qty = len(csv)
            validation_start = int(total_qty * training_proportion)
            validation_end = validation_start + int(total_qty * validation_proportion)
            training_df = pd.concat([training_df, csv.iloc[:validation_start]])
            validation_df = pd.concat([validation_df, csv.iloc[validation_start:validation_end]])
            testing_df = pd.concat([testing_df, csv.iloc[validation_end:]])

    training_df = training_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)
    validation_df = validation_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)
    testing_df = testing_df.sample(frac=1, random_state=RANDOM_SEED).reset_index(drop=True)

    return training_df, validation_df, testing_df

In [11]:
training_data, validation_data, testing_data = compile_dataset(GENERATED_DATA_DIR)

In [12]:
testing_data

Unnamed: 0,input,output,text
0,Describe the delivery stage of the Cyber Kill Chain.,"The delivery stage, the third in the Cyber Kill Chain, involves delivering the weaponized bundle to the victim through various means such as email, web, or USB drives.","<bos><start_of_turn>user\nDescribe the delivery stage of the Cyber Kill Chain.<end_of_turn>\n<start_of_turn>model\nThe delivery stage, the third in the Cyber Kill Chain, involves delivering the weaponized bundle to the victim through various means such as email, web, or USB drives.<end_of_turn>\n"
1,Provide the improved length checking condition and explain why it's better.,"The improved length checking condition is: `if (len1 <= sizeof(buf) && len2 <= sizeof(buf) && (len1 + len2 + 1 <= sizeof(buf)))`. \nThis condition is better because it addresses the `integer overflow` vulnerability more comprehensively. The original check `if (len1 + len2 + 1 <= sizeof(buf))` could be bypassed if `len1 + len2 + 1` overflowed and wrapped around to a small number, even if `len1` or `len2` individually were very large. The improved condition adds two preliminary checks: \n1. `len1 <= sizeof(buf)`: Ensures that the first length itself does not exceed the buffer's capacity. \n2. `len2 <= sizeof(buf)`: Ensures that the second length itself does not exceed the buffer's capacity. \nBy validating individual lengths first, it prevents scenarios where a single extremely large input (like `UINT_MAX`) could cause an overflow in the sum, leading to a false positive in the length check and a subsequent `Buffer Overflow`. The final check `(len1 + len2 + 1 <= sizeof(buf))` then safely verifies the combined length after the individual components have been deemed reasonable.","<bos><start_of_turn>user\nProvide the improved length checking condition and explain why it's better.<end_of_turn>\n<start_of_turn>model\nThe improved length checking condition is: `if (len1 <= sizeof(buf) && len2 <= sizeof(buf) && (len1 + len2 + 1 <= sizeof(buf)))`. \nThis condition is better because it addresses the `integer overflow` vulnerability more comprehensively. The original check `if (len1 + len2 + 1 <= sizeof(buf))` could be bypassed if `len1 + len2 + 1` overflowed and wrapped around to a small number, even if `len1` or `len2` individually were very large. The improved condition adds two preliminary checks: \n1. `len1 <= sizeof(buf)`: Ensures that the first length itself does not exceed the buffer's capacity. \n2. `len2 <= sizeof(buf)`: Ensures that the second length itself does not exceed the buffer's capacity. \nBy validating individual lengths first, it prevents scenarios where a single extremely large input (like `UINT_MAX`) could cause an overflow in the sum, leading to a false positive in the length check and a subsequent `Buffer Overflow`. The final check `(len1 + len2 + 1 <= sizeof(buf))` then safely verifies the combined length after the individual components have been deemed reasonable.<end_of_turn>\n"
2,Provide an example of how a criminal might impersonate a user's biometric data.,"For example, a criminal might impersonate a user's biometric data by creating a 'copy' of a user's fingerprint. This could be achieved by re-creating the fingerprint with a wax imprint that the criminal then places on top of their own finger to deceive the biometric scanner.","<bos><start_of_turn>user\nProvide an example of how a criminal might impersonate a user's biometric data.<end_of_turn>\n<start_of_turn>model\nFor example, a criminal might impersonate a user's biometric data by creating a 'copy' of a user's fingerprint. This could be achieved by re-creating the fingerprint with a wax imprint that the criminal then places on top of their own finger to deceive the biometric scanner.<end_of_turn>\n"
3,What are 'measures' put in place to prevent further queries?,Measures put in place to prevent further queries are security actions or configurations implemented to block or restrict unauthorized access attempts to the database.,<bos><start_of_turn>user\nWhat are 'measures' put in place to prevent further queries?<end_of_turn>\n<start_of_turn>model\nMeasures put in place to prevent further queries are security actions or configurations implemented to block or restrict unauthorized access attempts to the database.<end_of_turn>\n
4,What technologies should be considered for internet access strategy?,"Technologies to consider for an internet access strategy include internet surfing separation and internet isolation technology, with an assessment of their pros and cons.","<bos><start_of_turn>user\nWhat technologies should be considered for internet access strategy?<end_of_turn>\n<start_of_turn>model\nTechnologies to consider for an internet access strategy include internet surfing separation and internet isolation technology, with an assessment of their pros and cons.<end_of_turn>\n"
...,...,...,...
5523,Compare the scope of control gained by the attacker over Workstation A initially with the full control achieved over the domain later.,"Initially, with the introduction of the hacking tool and RAT 1 in December 2017, the attacker gained the capability to execute shell scripts remotely and upload/download files to Workstation A, indicating control over that specific workstation. Later, through privilege escalation and lateral movement, the attacker compromised the Windows authentication system and obtained administrator and user credentials, leading to full control over all Windows-based servers, hosted applications, all employee workstations, and underlying data within the domain. **Workstation A** refers to a single compromised computer, whereas **the domain** encompasses a broader network of interconnected computers, servers, and resources under a common security authority.","<bos><start_of_turn>user\nCompare the scope of control gained by the attacker over Workstation A initially with the full control achieved over the domain later.<end_of_turn>\n<start_of_turn>model\nInitially, with the introduction of the hacking tool and RAT 1 in December 2017, the attacker gained the capability to execute shell scripts remotely and upload/download files to Workstation A, indicating control over that specific workstation. Later, through privilege escalation and lateral movement, the attacker compromised the Windows authentication system and obtained administrator and user credentials, leading to full control over all Windows-based servers, hosted applications, all employee workstations, and underlying data within the domain. **Workstation A** refers to a single compromised computer, whereas **the domain** encompasses a broader network of interconnected computers, servers, and resources under a common security authority.<end_of_turn>\n"
5524,What does a pre-computed hash table contain?,"A pre-computed hash table contains pairs of passwords and their corresponding hash values. Essentially, it's a lookup table where an attacker can find a password if they have its hash, by matching the acquired hash with one of the pre-computed hashes in the table.","<bos><start_of_turn>user\nWhat does a pre-computed hash table contain?<end_of_turn>\n<start_of_turn>model\nA pre-computed hash table contains pairs of passwords and their corresponding hash values. Essentially, it's a lookup table where an attacker can find a password if they have its hash, by matching the acquired hash with one of the pre-computed hashes in the table.<end_of_turn>\n"
5525,How many possible passwords are there for a 5-character lowercase password?,"For a 5-character lowercase password, there are 26 possible choices for each character (the 26 lowercase alphabets). Therefore, the total number of possible passwords is 26 raised to the power of 5 (26^5).","<bos><start_of_turn>user\nHow many possible passwords are there for a 5-character lowercase password?<end_of_turn>\n<start_of_turn>model\nFor a 5-character lowercase password, there are 26 possible choices for each character (the 26 lowercase alphabets). Therefore, the total number of possible passwords is 26 raised to the power of 5 (26^5).<end_of_turn>\n"
5526,"Compare passphrases with traditional, shorter passwords in terms of memorability and security.","A 'passphrase' involves users entering sentences or long phrases that are generally easier to remember than complex, shorter passwords. The system applies a hash function to compute the fixed-size actual passwords from these passphrases. While traditional short passwords often rely on complexity (mix of characters), passphrases leverage length and natural language for both memorability and enhanced security against brute-force attacks due to their extended length.","<bos><start_of_turn>user\nCompare passphrases with traditional, shorter passwords in terms of memorability and security.<end_of_turn>\n<start_of_turn>model\nA 'passphrase' involves users entering sentences or long phrases that are generally easier to remember than complex, shorter passwords. The system applies a hash function to compute the fixed-size actual passwords from these passphrases. While traditional short passwords often rely on complexity (mix of characters), passphrases leverage length and natural language for both memorability and enhanced security against brute-force attacks due to their extended length.<end_of_turn>\n"


In [13]:
TRAINING_ENTRIES = len(training_data)
VALIDATION_ENTRIES = len(validation_data)
TESTING_ENTRIES = len(testing_data)

print(f"""Training Proportion = {TRAINING_PROPORTION}
Training Entries = {TRAINING_ENTRIES}

Validation Proportion = {VALIDATION_PROPORTION}
Validation Entries = {VALIDATION_ENTRIES}

Testing Proportion = {TESTING_PROPORTION}
Testing Entries = {TESTING_ENTRIES}""")

Training Proportion = 0.7
Training Entries = 34757

Validation Proportion = 0.2
Validation Entries = 9735

Testing Proportion = 0.1
Testing Entries = 5528


In [14]:
TRAINING_OUTPUT_DIR = os.getenv('TRAINING_OUTPUT_DIR')

print(f"""Training Output Directory = {TRAINING_OUTPUT_DIR}""")

Training Output Directory = model_training/results


## Model Evaluation

In [15]:
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, 
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True,
    token=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [16]:
prompt = "How does an attacker use gadgets in Return-Oriented Programming?"
response = generate(base_model, prompt)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
INPUT [103]:
<bos><start_of_turn>user
How does an attacker use gadgets in Return-Oriented Programming?<end_of_turn>

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
OUTPUT [2964]:
<bos><bos><start_of_turn>user
How does an attacker use gadgets in Return-Oriented Programming?<end_of_turn>
<eos>

**Return-oriented programming (ROP) is a technique that exploits vulnerabilities by manipulating the program's execution flow to achieve unauthorized 

In [17]:
prompt = "How does TPM ensure that data can only be decrypted on the correct platform?"
response = generate(base_model, prompt)

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
INPUT [115]:
<bos><start_of_turn>user
How does TPM ensure that data can only be decrypted on the correct platform?<end_of_turn>

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
OUTPUT [2313]:
<bos><bos><start_of_turn>user
How does TPM ensure that data can only be decrypted on the correct platform?<end_of_turn>
<eos>

TPM (Trusted Platform Module) is a hardware security module built into many computers. It's designed to provide secure storage

In [None]:
original_model_outputs = []

for index in testing_data.index:
    original_model_outputs.append(generate(base_model, testing_data['input'][index]))

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
INPUT [91]:
<bos><start_of_turn>user
Describe the delivery stage of the Cyber Kill Chain.<end_of_turn>

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
OUTPUT [2844]:
<bos><bos><start_of_turn>user
Describe the delivery stage of the Cyber Kill Chain.<end_of_turn>
<eos>

The **CyberKillChain** is a framework for understanding and analyzing cyberattacks, developed by MITRE Corporation in 2014 to help organizations understand how attackers opera

In [None]:
SELECTED_STEP = int(os.getenv("SELECTED_STEP"))

print(f"""Selected Step = {SELECTED_STEP}""")

In [None]:
tuned_model = PeftModel.from_pretrained(base_model, 
                                        f"{TRAINING_OUTPUT_DIR}/checkpoint-{SELECTED_STEP}",
                                        torch_dtype=torch.float16,
                                        is_trainable=False)

In [None]:
prompt = "How does an attacker use gadgets in Return-Oriented Programming?"
response = generate(tuned_model, prompt)

In [None]:
prompt = "How does TPM ensure that data can only be decrypted on the correct platform?"
response = generate(tuned_model, prompt)

In [None]:
tuned_model_outputs = []

for index in testing_data.index:
    tuned_model_outputs.append(generate(tuned_model, testing_data['input'][index]))

In [None]:
rouge = evaluate.load('rouge')

original_model_results = rouge.compute(
    predictions=original_model_outputs,
    references=testing_data['output'].values,
    use_aggregator=True,
    use_stemmer=True,
)

peft_model_results = rouge.compute(
    predictions=tuned_model_outputs,
    references=testing_data['output'].values,
    use_aggregator=True,
    use_stemmer=True,
)

print(f'Original Model:\n{original_model_results}\n')
print(f'Tuned Model:\n{peft_model_results}\n')

In [None]:
print("Percentage ImprovementsL\n")
improvement = (np.array(list(peft_model_results.values())) - np.array(list(original_model_results.values())))
for key, value in zip(peft_model_results.keys(), improvement):
    print(f'{key}: {value*100:.2f}%')