In [1]:
from datasets import load_dataset, Dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from torch.utils.data import DataLoader
from peft import LoraConfig
import torch
import pandas as pd
import json
import re
import glob
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import classification_report, hamming_loss

treat the dataset, from raw log into conversational format (only need to run once)

In [2]:
dataset_path = glob.glob("../data/*.csv")
taxonomy_map = {
    "info": 1,
    "injection": 2,
    "traversal": 3,
    "rce": 4,
    "proxy": 5,
    "xss": 6,
    "llm": 7,
    "other": 8
}

In [3]:
def load_csv(dataset_path):
    df = pd.read_csv(dataset_path, skiprows=lambda x: x in range(1), names=['log', 'label', 'category', 'misc'])
    return df

In [4]:
df = pd.concat((load_csv(f) for f in dataset_path), ignore_index=True)
df = df.drop(df.columns[3], axis=1)
df["category"] = df["category"].apply(    
    lambda x: sorted([
        taxonomy_map[k.strip().lower()]
        for k in str(x).split(',')
        if k.strip().lower() in taxonomy_map
    ]) if pd.notna(x) else [0]
    )

In [5]:
def generate_prompt(log):
    user_prompt = "Given a log entry collected from an Apache HTTP server, classify it as either \"Malicious\" or \"Benign\".\n\n\
        If the log is classified as malicious, specify the reason(s) (can be multiple) by selecting from the following categories: \n\n\
        1. information exposure (reconaissance, scanning)\n \
        2. injection (including command injection, sql injection, XML external entity attack, shellcode injection)\n \
        3. path traversal\n\
        4. remote code execution\n\
        5. proxy-based attack (Server-Side Request Forgery, open redirect)\n \
        6. cross site scripting\n\
        7. prompt injection targeting LLM models\n\
        8. other (not mentioned above, e.g., local file inclusion, remote file inclusion, etc.)\n\n\
        Return your answer in strict JSON format for structured parsing. Use the following format:\n\n{{\n  \"classification\": \"Malicious or Benign\",\n  \"reason\": \"Comma-separated list of category numbers if malicious; leave empty if benign\"\n}}\n #### Explaination: why the weblog provided is malicious, leave empty if benign.\n\
        After the JSON briefly explain the reasoning for malicious classifications, if the log is benign, no explanation is needed."
    messages = [
        {"role": "system", "content": "You are a cybersecurity expert analyzing Apache log entries to detect potential security threats."},
        {"role": "user", "content": user_prompt + "\n\
        Log:" + log},
    ]
    return messages
def generate_response(label, category):
    if label == 0:
        return {"role": "assistant", "content": "```json {{ \n \"classification\" : \"Benign\", \n \"reason\":\"\"\n}}\n```"}
    else:
        return {"role": "assistant", "content": f"```json {{ \n \"classification\" : \"Malicious\", \n \"reason\":\"{str(category)}\"\n}}\n```"}

In [6]:
dicts = []
for _, row in df.iterrows():
    conversation = generate_prompt(row.iloc[0])
    conversation.append(generate_response(row.iloc[1], row.iloc[2]))
    entry = {"messages": conversation}
    dicts.append(entry)

with open("../data/prompt.json", "w", encoding="utf-8") as f:
    json.dump(dicts, f, indent=2, ensure_ascii=False)    

load dataset from json into huggingface Dataset format

In [7]:
dataset = load_dataset("json", data_files="../data/prompt.json", split='train')
dataset= dataset.train_test_split(test_size=0.2)

dataset

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['messages'],
        num_rows: 2737
    })
    test: Dataset({
        features: ['messages'],
        num_rows: 685
    })
})

Training step: using LoRA with SFT trainer

In [8]:
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_token"],
    task_type="CAUSAL_LM",
)

In [9]:
max_memory = {0: torch.cuda.get_device_properties(0).total_memory}
print(max_memory)

{0: 47725936640}


In [10]:
checkpoint='/rds/general/user/rm521/home/fyp/qwen2.5-7B'

model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16, max_memory=max_memory)
trainer = SFTTrainer(
    model,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    args=SFTConfig(
        output_dir="Qwen2.5-7B-SFT", 
        do_eval=True,
        per_device_train_batch_size=2,
    ),
    peft_config=peft_config,
)

Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]


[32mINFO[0m  ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.


Converting train dataset to ChatML:   0%|          | 0/2737 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/2737 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/2737 [00:00<?, ? examples/s]

Truncating train dataset:   0%|          | 0/2737 [00:00<?, ? examples/s]

Converting eval dataset to ChatML:   0%|          | 0/685 [00:00<?, ? examples/s]

Applying chat template to eval dataset:   0%|          | 0/685 [00:00<?, ? examples/s]

Tokenizing eval dataset:   0%|          | 0/685 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/685 [00:00<?, ? examples/s]

No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [11]:
trainer.train()

Step,Training Loss
500,0.2463
1000,0.1221
1500,0.1054
2000,0.0934
2500,0.0838
3000,0.0801
3500,0.0696
4000,0.0722


TrainOutput(global_step=4107, training_loss=0.10833801328047424, metrics={'train_runtime': 2741.6288, 'train_samples_per_second': 2.995, 'train_steps_per_second': 1.498, 'total_flos': 1.6382278783785062e+17, 'train_loss': 0.10833801328047424})

Evaluation step
if model is not loaded in VRAM, reload from file

In [12]:
torch.cuda.empty_cache()

In [13]:
model_name = '/rds/general/user/rm521/home/fyp/step3-sft/Qwen2.5-7B-SFT/checkpoint-4107'
sft_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, max_memory=max_memory)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [14]:
def collate_fn(batch):
    # Each `batch[i]` is a dictionary with a "messages" field
    refs = [sample["messages"][2] for sample in batch]
    prompts = [
        tokenizer.apply_chat_template(
            sample["messages"][:2],  # assuming system + user
            tokenize=False,
            add_generation_prompt=True
        )
        for sample in batch
    ]
    tokenized = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        padding_side='left',
        truncation=True,
    )
    tokenized["ref"] = refs
    return tokenized

# Step 2: Create DataLoader over a sliced dataset
subset = dataset["test"]
loader = DataLoader(subset, batch_size=5, collate_fn=collate_fn)

In [15]:
import ast
def parse_label_string(s): # parsing labels such as [1,2,3], [4]
    s = s.strip()
    if not s:
        return [0]
    try:
        parsed = ast.literal_eval(s)
        if isinstance(parsed, int):
            return [parsed]
        elif isinstance(parsed, list):
            return [int(x) for x in parsed]
        else:
            return [0]
    except Exception:
        return [0]

def multi_label(response):
    for line in response.split("\n"):
        if "reason" in line:
            matched = True
            match = re.search(r'"reason":\s*"([^"]*)"', line)
            if match:
                reason_str = match.group(1)
                return parse_label_string(reason_str)
            else:
                return [0]
    if not matched:
        return [0]

In [16]:
eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")

total_cnt = 0
correct_cnt = 0

generated_responses = []
references = []

for batch in loader:
    refs = batch.pop("ref")
    batch = {k: v.to(sft_model.device) for k, v in batch.items()}
    
    with torch.no_grad():
        generated = sft_model.generate(
            **batch,
            max_new_tokens=512,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    
    # Slice off prompt inputs to get only the generated completion
    generated_ids = [
        output[len(input_ids):]
        for input_ids, output in zip(batch["input_ids"], generated)
    ]
    
    responses = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    generated_responses += responses
    references += refs

In [20]:
resps = [multi_label(text) for text in generated_responses]
refs = [multi_label(item['content']) for item in references]

In [21]:
mlb = MultiLabelBinarizer()
y_true = mlb.fit_transform(refs)
y_pred = mlb.transform(resps)

In [22]:
class_labels = [str(label) for label in mlb.classes_]
print(class_labels)
print("Classification Report:\n", classification_report(y_true, y_pred, target_names=class_labels))
hamming = hamming_loss(y_true, y_pred)
print(f"Hamming Loss: {hamming}")

['0', '1', '2', '3', '4', '5', '6', '7', '8']
Classification Report:
               precision    recall  f1-score   support

           0       0.94      0.93      0.94       258
           1       0.94      0.94      0.94       340
           2       0.75      0.75      0.75        16
           3       0.91      0.84      0.87        37
           4       0.85      0.89      0.87        37
           5       0.71      0.86      0.77        14
           6       0.88      0.93      0.90        15
           7       0.78      0.78      0.78         9
           8       0.20      0.33      0.25         3

   micro avg       0.92      0.92      0.92       729
   macro avg       0.77      0.81      0.79       729
weighted avg       0.92      0.92      0.92       729
 samples avg       0.92      0.92      0.92       729

Hamming Loss: 0.01962692619626926
