In [1]:
from torch.utils.data import DataLoader
from tqdm import tqdm

def train_lm_critic(model, dataloader, optimizer, criterion, device="cpu", epochs=10):
    model.train()
    total_epoch_loss = 0
    for epoch in range(epochs):
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}")
        for batch_idx, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            # print(f"labels: {labels} {labels.shape} {labels.dtype}")
            # Forward pass
            logits = model({"input_ids": input_ids, "attention_mask": attention_mask})
            loss = criterion(logits, labels)

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_epoch_loss += loss.item()
            running_loss = total_epoch_loss / (batch_idx + 1)
            total_loss += loss.item()
            pbar.set_postfix(loss=f"{running_loss:.4f}")
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")

def validate_lm_critic(model, dataloader, criterion, device="cpu"):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        pbar = tqdm(dataloader, desc="Validating")
        for batch_idx, batch in enumerate(pbar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            # Forward pass
            logits = model({"input_ids": input_ids, "attention_mask": attention_mask})
            loss = criterion(logits, labels)

            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            running_loss = total_loss / (batch_idx + 1)
            pbar.set_postfix(loss=f"{running_loss:.4f}", accuracy=f"{correct / total:.4f}")

    accuracy = correct / total
    print(f"Validation Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}")
    return total_loss, accuracy


In [2]:
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer

class LMCritic(nn.Module):
    def __init__(self, lm_model_name, lm_hidden_dim, critic_hidden_dim, output_dim):
        """
        Language Model Critic for syscall sequence classification.
        
        Args:
            lm_model_name (str): Pretrained language model name (e.g., "distilbert-base-uncased").
            lm_hidden_dim (int): Hidden dimension of the LM's output.
            critic_hidden_dim (int): Hidden dimension of the critic module.
            output_dim (int): Number of output classes (e.g., 2 for binary classification).
        """
        super(LMCritic, self).__init__()
        self.lm = AutoModel.from_pretrained(lm_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(lm_model_name)
        self.critic = nn.Sequential(
            nn.Linear(lm_hidden_dim, critic_hidden_dim),
            nn.ReLU(),
            nn.Linear(critic_hidden_dim, output_dim)
        )

    def forward(self, sequences):
        """
        Forward pass for the LM-Critic model.
        
        Args:
            sequences (torch.Tensor): Batch of syscall sequences (tokenized input).
        
        Returns:
            logits (torch.Tensor): Model predictions for each sequence.
        """
        # Pass through the language model
        lm_outputs = self.lm(**sequences)
        cls_embedding = lm_outputs.last_hidden_state[:, 0, :]  # Use [CLS] embedding

        # Pass through the critic module
        logits = self.critic(cls_embedding)
        return logits


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
!head syscalls_11222024_1203AM.log
!head syscalls.txt

Attaching 357 probes...
[11156347] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156348] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156348] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156348] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156348] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156349] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156349] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156349] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
[11156350] PID: 1605, Comm: pipewire-media-, Syscall: tracepoint:syscalls:sys_enter_epoll_wait
0	read
1	write
2	open
3	close
4	stat
5	fstat
6	lstat
7	poll
8	lseek
9	mmap


In [4]:
f = open('syscalls.txt', 'r')
x86_64_syscalls = f.read()
f.close()

def parse_syscall_table(table):
    syscall_vocab = {}
    for i, line in enumerate(table.strip().split("\n")):
        parts = line.split("\t")
        index = int(parts[0])  # Syscall index
        name = parts[1]        # Syscall name
        syscall_vocab[name] = i
    return syscall_vocab

syscall_vocab = parse_syscall_table(x86_64_syscalls)
print(f"syscall vocab size is {len(syscall_vocab)}")
print(syscall_vocab)
from os import times
import pandas as pd
import itertools
from tqdm import tqdm

def preprocess_flow_file(file_path):
    data = []

    with open(file_path, 'r') as file:
        total_lines = sum(1 for _ in file) - 2  # Subtract 2 for skipped lines

    with open(file_path, 'r') as file:
        next(file)

        for line in tqdm(itertools.islice(file, None, None, None), total=total_lines, desc="Processing lines"):
            try:
                parts = line.strip().split(", ")
                timestamp, _, pid = parts[0].split(" ")
                timestamp = int(timestamp.strip("[]"))
                pid = int(pid)
                comm = parts[1].split(": ")[1]
                syscall = parts[2].split(": ")[1]
                data.append({"Timestamp": timestamp, "PID": pid, "Comm": comm, "Syscall": syscall})
            except Exception as e:
                print(f"Error processing line: {line.strip()}")
                print(f"Error: {e}")

    df = pd.DataFrame(data)
    return df

flow_file_path = "syscalls_11222024_1203AM.log"
flow_data = preprocess_flow_file(flow_file_path)
flow_data.head()

syscall vocab size is 361
{'read': 0, 'write': 1, 'open': 2, 'close': 3, 'stat': 4, 'fstat': 5, 'lstat': 6, 'poll': 7, 'lseek': 8, 'mmap': 9, 'mprotect': 10, 'munmap': 11, 'brk': 12, 'rt_sigaction': 13, 'rt_sigprocmask': 14, 'rt_sigreturn': 15, 'ioctl': 16, 'pread': 17, 'pwrite': 18, 'readv': 19, 'writev': 20, 'access': 21, 'pipe': 22, 'select': 23, 'sched_yield': 24, 'mremap': 25, 'msync': 26, 'mincore': 27, 'madvise': 28, 'shmget': 29, 'shmat': 30, 'shmctl': 31, 'dup': 32, 'dup2': 33, 'pause': 34, 'nanosleep': 35, 'getitimer': 36, 'alarm': 37, 'setitimer': 38, 'getpid': 39, 'sendfile': 40, 'socket': 41, 'connect': 42, 'accept': 43, 'sendto': 44, 'recvfrom': 45, 'sendmsg': 46, 'recvmsg': 47, 'shutdown': 48, 'bind': 49, 'listen': 50, 'getsockname': 51, 'getpeername': 52, 'socketpair': 53, 'setsockopt': 54, 'getsockopt': 55, 'clone': 56, 'fork': 57, 'vfork': 58, 'execve': 59, 'exit': 60, 'wait4': 61, 'kill': 62, 'uname': 63, 'semget': 64, 'semop': 65, 'semctl': 66, 'shmdt': 67, 'msgget'

Processing lines: 11264394it [00:11, 1000497.04it/s]                              


Error processing line: @self: 3167
Error: not enough values to unpack (expected 3, got 2)


Unnamed: 0,Timestamp,PID,Comm,Syscall
0,11156347,1605,pipewire-media-,tracepoint:syscalls:sys_enter_epoll_wait
1,11156348,1605,pipewire-media-,tracepoint:syscalls:sys_enter_epoll_wait
2,11156348,1605,pipewire-media-,tracepoint:syscalls:sys_enter_epoll_wait
3,11156348,1605,pipewire-media-,tracepoint:syscalls:sys_enter_epoll_wait
4,11156348,1605,pipewire-media-,tracepoint:syscalls:sys_enter_epoll_wait


In [5]:
flow_data.tail()

Unnamed: 0,Timestamp,PID,Comm,Syscall
11264388,21079220,829,avahi-daemon,tracepoint:syscalls:sys_enter_write
11264389,21079231,833,NetworkManager,tracepoint:syscalls:sys_enter_write
11264390,21079231,833,NetworkManager,tracepoint:syscalls:sys_enter_write
11264391,21079231,833,NetworkManager,tracepoint:syscalls:sys_enter_write
11264392,21079231,833,NetworkManager,tracepoint:syscalls:sys_enter_write


In [6]:
flow_data["Timestamp"] = flow_data["Timestamp"] / 1000

In [7]:
import hashlib

prefix = "tracepoint:syscalls:sys_enter_"
flow_data["Syscall"] = flow_data["Syscall"].str.replace(prefix, "", regex=False)

flow_data["Syscall"] = flow_data["Syscall"].map(syscall_vocab)

def consistent_hash(command, table_size=10000):
    hash_object = hashlib.md5(command.encode())  # Use MD5 for consistent hashing
    hash_value = int(hash_object.hexdigest(), 16)  # Convert hash to an integer
    return hash_value % table_size

flow_data["Comm"] = flow_data["Comm"].map(lambda cmd: consistent_hash(cmd))

In [8]:
flow_data = flow_data.dropna()

In [9]:
flow_data.head()

Unnamed: 0,Timestamp,PID,Comm,Syscall
0,11156.347,1605,503,232.0
1,11156.348,1605,503,232.0
2,11156.348,1605,503,232.0
3,11156.348,1605,503,232.0
4,11156.348,1605,503,232.0


In [10]:
flow_data.tail()

Unnamed: 0,Timestamp,PID,Comm,Syscall
11264388,21079.22,829,8116,1.0
11264389,21079.231,833,8520,1.0
11264390,21079.231,833,8520,1.0
11264391,21079.231,833,8520,1.0
11264392,21079.231,833,8520,1.0


In [11]:
flow_data['PID'].nunique()

260

In [12]:
flow_data['PID'].max()

np.int64(3610)

In [13]:
flow_data['Comm'].max()

np.int64(9983)

In [14]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm

labels = ["normal"] * flow_data.shape[0]
labels = pd.Series(labels)

sequence_length = 50
batch_size = 32

# Sliding window function
def create_sliding_windows(data, labels, window_size=300, step_size=150):
    windows = []
    labels_mode = []

    for start_idx in tqdm(range(0, len(data) - window_size + 1, step_size)):
        window = data.iloc[start_idx:start_idx + window_size]
        label_window = labels[start_idx:start_idx + window_size]
        windows.append(window)
        labels_mode.append(label_window.mode()[0])  # Majority label in the window

    return windows, labels_mode

# Map labels to numerical values
windows, window_labels = create_sliding_windows(flow_data, labels, sequence_length, sequence_length//2)

label_mapping = {"normal": 0, "malicious": 1}
numerical_window_labels = [label_mapping[label] for label in window_labels]

num_classes = len(label_mapping)
one_hot_labels = np.eye(num_classes)[numerical_window_labels]  # One-hot encoding

100%|██████████| 449028/449028 [00:28<00:00, 15722.02it/s]


In [15]:
print(windows[1])

    Timestamp   PID  Comm  Syscall
25  11156.355  1605   503    232.0
26  11156.355  1605   503    232.0
27  11156.355  1605   503    232.0
28  11156.356  1605   503    232.0
29  11156.356  1605   503    232.0
30  11156.356  1605   503    232.0
31  11156.356  1605   503    232.0
32  11156.357  1605   503    232.0
33  11156.357  1605   503    232.0
34  11156.357  1605   503    232.0
35  11156.358  1605   503    232.0
36  11156.358  1605   503    232.0
37  11156.358  1605   503    232.0
38  11156.358  1605   503    232.0
39  11156.359  1605   503    232.0
40  11156.359  1605   503    232.0
41  11156.359  1605   503    232.0
42  11156.360  1605   503    232.0
43  11156.360  1605   503    232.0
44  11156.360  1605   503    232.0
45  11156.361  1605   503    232.0
46  11156.361  1605   503    232.0
47  11156.361  1605   503    232.0
48  11156.362  1605   503    232.0
49  11156.362  1605   503    232.0
50  11156.362  1605   503    232.0
51  11156.363  1605   503    232.0
52  11156.363  1605 

In [16]:
import numpy as np
import torch
from torch.utils.data import Dataset

# Prepare data
train_size = int(0.9 * len(windows))  # First 90% for training
val_size = len(windows) - train_size  # Last 10% for validation

train_windows = windows[:train_size]
val_windows = windows[train_size:]

train_labels = one_hot_labels[:train_size]
val_labels = one_hot_labels[train_size:]

import copy
def inject_anomalies(window, label, anomaly_prob=0.2):
    """Inject anomalies in a contiguous portion of the window."""
    if np.random.rand() > anomaly_prob:
        return window, [0, 0, 1], 2

    anomalous_window = window.copy()
    anomaly_type = np.random.choice([1, 0])
    
    # Determine the size of the anomalous portion (between 10% and 50% of window)
    window_size = window.shape[0]
    anomaly_length = np.random.randint(window_size // 10, window_size // 2)
    
    # Randomly choose whether to inject from start or end
    from_start = np.random.choice([True, False])
    
    if from_start:
        anomaly_slice = slice(0, anomaly_length)
    else:
        anomaly_slice = slice(window_size - anomaly_length, window_size)
    
    if anomaly_type == 1:  # APT
        # Modify only comm_id and syscall in the chosen slice
        anomalous_window[anomaly_slice, 2] = np.random.randint(10000, 10101, size=anomaly_length)  # Modify comm_id
        anomalous_window[anomaly_slice, 3] = np.random.randint(0, 361, size=anomaly_length)  # Modify syscall
        label = [0, 1, 0]
    elif anomaly_type == 0:  # Blatant
        # Modify PID, comm_id, and syscall in the chosen slice
        anomalous_window[anomaly_slice, 1] = np.random.randint(10000, 10101, size=anomaly_length)  # Modify PID
        anomalous_window[anomaly_slice, 2] = np.random.randint(10000, 10101, size=anomaly_length)  # Modify comm_id
        anomalous_window[anomaly_slice, 3] = np.random.randint(0, 361, size=anomaly_length)  # Modify syscall
        label = [1, 0, 0]

    return anomalous_window, label, anomaly_type
   

def augment_windows(windows, labels, anomaly_prob=0.2):
    """
    Augments syscall windows by injecting anomalies based on a given probability.

    Args:
        windows (list or np.ndarray): List/array of syscall windows.
        labels (list or np.ndarray): List/array of one-hot encoded labels.
        inject_anomalies_fn (function): Function to inject anomalies.
        anomaly_prob (float): Probability of injecting anomalies.

    Returns:
        tuple: Augmented windows and labels (original + anomalous).
    """
    augmented_windows = []
    augmented_labels = []
    augmentation_types = []

    for window, label in tqdm(zip(windows, labels), total=len(windows), desc="Augmenting Data"):
        # Original window and label
        # augmented_windows.append(window)
        # augmented_labels.append(label)
        window = np.array(window)
        label = np.array(label)
        # Inject anomalies with the specified probability
        anomalous_window, anomalous_label, anomaly_type = inject_anomalies(window.copy(), label.copy(), anomaly_prob=anomaly_prob)
        augmented_windows.append(anomalous_window)
        augmented_labels.append(anomalous_label)
        augmentation_types.append(anomaly_type)

    return np.array(augmented_windows), np.array(augmented_labels), np.array(augmentation_types)

train_windows_augmented, train_labels_augmented, train_augmentation_types = augment_windows(
    train_windows,
    train_labels,
    anomaly_prob=0.5  # 50% anomaly probability
)

val_windows_augmented, val_labels_augmented, val_augmentation_types = augment_windows(
    val_windows,
    val_labels,
    anomaly_prob=0.5
)

print("\nOriginal Training Windows:", len(train_windows))
print("Augmented Training Windows:", len(train_windows_augmented))
print("Original Validation Windows:", len(val_windows))
print("Augmented Validation Windows:", len(val_windows_augmented))

Augmenting Data:   0%|          | 0/404125 [00:00<?, ?it/s]

Augmenting Data: 100%|██████████| 404125/404125 [00:15<00:00, 25571.65it/s]
Augmenting Data: 100%|██████████| 44903/44903 [00:01<00:00, 28934.71it/s]


Original Training Windows: 404125
Augmented Training Windows: 404125
Original Validation Windows: 44903
Augmented Validation Windows: 44903





In [17]:
class SyscallTokenizer:
    def __init__(self):
        # Basic vocabulary
        self.vocab = {
            '[PAD]': 0,
            '[CLS]': 1,
            '[SEP]': 2,
            '[MASK]': 3,  # Needed for MLM
            'PID': 4,
            'Comm': 5,
            'Syscall': 6,
            ':': 7,
            '0': 8, '1': 9, '2': 10, '3': 11, '4': 12,
            '5': 13, '6': 14, '7': 15, '8': 16, '9': 17
        }
        self.inverse_vocab = {v: k for k, v in self.vocab.items()}
        self.vocab_size = len(self.vocab)
        
    def tokenize(self, text):
        """Convert text to list of tokens"""
        tokens = []
        current_number = ''
        splits = text.strip().split()
        # print(f"splits: {splits}")
        for token in splits:
            if token.isdigit():
                current_number += token
            else:
                # Process any accumulated number
                if current_number:
                    tokens.extend(list(current_number))
                    current_number = ''
                
                # Process non-digit char
                if token.isspace():
                    continue
                # if char in self.vocab:
                if token[:-1] in self.vocab:
                    tokens.append(token[:-1])
                    tokens.append(token[-1])
            # print(f"tokens: {tokens}")
            # print(f"current_number: {current_number}")
        # Handle any remaining number
        if current_number:
            tokens.extend(list(current_number))
            
        return tokens
    
    def encode(self, text, max_length=3072, padding=True):
        """Convert text to token IDs"""
        tokens = ['[CLS]'] + self.tokenize(text) + ['[SEP]']
        ids = [self.vocab[token] for token in tokens]
        
        if padding and len(ids) < max_length:
            ids = ids + [self.vocab['[PAD]']] * (max_length - len(ids))
        
        attention_mask = [1] * len(tokens) + [0] * (max_length - len(tokens)) if padding else [1] * len(tokens)
        
        return {
            'input_ids': torch.tensor([ids]),
            'attention_mask': torch.tensor([attention_mask])
        }
    
    def decode(self, ids):
        """Convert token IDs back to text"""
        return ' '.join(self.inverse_vocab[id.item()] for id in ids if id != self.vocab['[PAD]'])

class SyscallDataset(Dataset):
    def __init__(self, windows, labels, types, tokenizer, max_length=3072):
        self.windows = windows
        self.labels = labels
        self.types = types
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.preprocessed_windows = []
        self._preprocess_windows()
    
    def _preprocess_windows(self):
        for window in tqdm(self.windows, desc="Preprocessing Windows"):
            # Format each sequence
            sequence = " ".join([
                f"PID: {' '.join(str(int(row[1])))} "
                f"Comm: {' '.join(str(int(row[2])))} "
                f"Syscall: {' '.join(str(int(row[3])))}"
                for row in window
            ])
            print(sequence)
            encoded = self.tokenizer.encode(
                sequence,
                max_length=self.max_length,
                padding=True
            )
            print(encoded)
            self.preprocessed_windows.append(encoded)
    
    def __len__(self):
        return len(self.windows)
    
    def __getitem__(self, idx):
        return {
            "input_ids": self.preprocessed_windows[idx]["input_ids"].squeeze(0),
            "attention_mask": self.preprocessed_windows[idx]["attention_mask"].squeeze(0),
            "label": torch.tensor(self.labels[idx], dtype=torch.float),
            "type": torch.tensor(self.types[idx], dtype=torch.int)
        }

In [18]:
# device = "mps"
# # Initialize dataset and dataloader
# tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
# train_dataset = SyscallDataset(train_windows_augmented, train_labels_augmented, train_augmentation_types, tokenizer)
# val_dataset = SyscallDataset(val_windows_augmented, val_labels_augmented, val_augmentation_types, tokenizer)

# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=False)
# val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# # Initialize model, optimizer, and loss function
# model = LMCritic("distilbert-base-uncased", lm_hidden_dim=768, critic_hidden_dim=256, output_dim=3).to(device)
# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# criterion = nn.CrossEntropyLoss()

In [19]:
from transformers import AutoTokenizer

# Example windows: List of sequences with [PID, Comm, Syscall]
ex_windows = [
    [[1, 829, 8116, 1.0], [1, 833, 8520, 1.0], [1, 833, 8520, 2.0]],
    [[1, 830, 8117, 1.0], [1, 834, 8521, 3.0], [1, 835, 8522, 4.0]],
]

labels = [1, 0]  # Non-anomalous and anomalous labels
types = [0, 1]   # Auxiliary data or type information

# Initialize the tokenizer
tokenizer = SyscallTokenizer()

# Create the dataset
dataset = SyscallDataset(ex_windows, labels, types, tokenizer, max_length=128)


Preprocessing Windows: 100%|██████████| 2/2 [00:00<00:00, 459.15it/s]

PID: 8 2 9 Comm: 8 1 1 6 Syscall: 1 PID: 8 3 3 Comm: 8 5 2 0 Syscall: 1 PID: 8 3 3 Comm: 8 5 2 0 Syscall: 2
{'input_ids': tensor([[ 1,  4,  7, 16, 10, 17,  5,  7, 16,  9,  9, 14,  6,  7,  9,  4,  7, 16,
         11, 11,  5,  7, 16, 13, 10,  8,  6,  7,  9,  4,  7, 16, 11, 11,  5,  7,
         16, 13, 10,  8,  6,  7, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 




In [20]:
# Example: Access a single data point
sample = dataset[0]

print("Input IDs:", sample["input_ids"])
print("Attention Mask:", sample["attention_mask"])
print("Label:", sample["label"])
print("Type:", sample["type"])
dataset.preprocessed_windows[0]

Input IDs: tensor([ 1,  4,  7, 16, 10, 17,  5,  7, 16,  9,  9, 14,  6,  7,  9,  4,  7, 16,
        11, 11,  5,  7, 16, 13, 10,  8,  6,  7,  9,  4,  7, 16, 11, 11,  5,  7,
        16, 13, 10,  8,  6,  7, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0])
Attention Mask: tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

{'input_ids': tensor([[ 1,  4,  7, 16, 10, 17,  5,  7, 16,  9,  9, 14,  6,  7,  9,  4,  7, 16,
          11, 11,  5,  7, 16, 13, 10,  8,  6,  7,  9,  4,  7, 16, 11, 11,  5,  7,
          16, 13, 10,  8,  6,  7, 10,  2,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
           0,  0]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [21]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

# Initialize the tokenizer and model
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
model = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased")

# Example sequence
sequence = "PID: 1605 Comm: 503 Syscall: 232"

# Tokenize the sequence

In [23]:
train_windows_numpy = np.array(train_windows)
val_windows_numpy = np.array(val_windows)

In [24]:
train_sequences = [" ".join([
    f"PID: {str(int(row[1]))} "
    f"Comm: {str(int(row[2]))} "
    f"Syscall: {str(int(row[3]))}"
    for row in window
]) for window in tqdm(train_windows_numpy)]
train_sequences[0]
val_sequences = [" ".join([
    f"PID: {str(int(row[1]))} "
    f"Comm: {str(int(row[2]))} "
    f"Syscall: {str(int(row[3]))}"
    for row in window
]) for window in tqdm(val_windows_numpy)]
val_sequences[0]

100%|██████████| 404125/404125 [00:14<00:00, 27961.49it/s]
100%|██████████| 44903/44903 [00:01<00:00, 29172.36it/s]


'PID: 829 Comm: 8116 Syscall: 7 PID: 829 Comm: 8116 Syscall: 16 PID: 829 Comm: 8116 Syscall: 47 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 0 PID: 829 Comm: 8116 Syscall: 7 PID: 736 Comm: 9057 Syscall: 100 PID: 736 Comm: 9057 Syscall: 100 PID: 736 Comm: 9057 Syscall: 100 PID: 736 Comm: 9057 Syscall: 100 PID: 736 Comm: 9057 Syscall: 100 PID: 736 Comm: 9057 Syscall: 7 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 0 PID: 829 Comm: 8116 Syscall: 7 PID: 1742 Comm: 6962 Syscall: 47 PID: 1742 Comm: 6962 Syscall: 7 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 0 PID: 829 Comm: 8116 Syscall: 7 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Syscall: 1 PID: 829 Comm: 8116 Sys

In [25]:
len(train_sequences[0])

1649

In [26]:
inputs = tokenizer(train_sequences[0], return_tensors="pt", max_length=768, padding="max_length", truncation=True)
print(inputs)
print(tokenizer.decode(inputs.input_ids[0]))

{'input_ids': tensor([[  101, 14255,  2094,  1024, 28202,  4012,  2213,  1024,  2753,  2509,
         25353, 15782,  3363,  1024, 20666, 14255,  2094,  1024, 28202,  4012,
          2213,  1024,  2753,  2509, 25353, 15782,  3363,  1024, 20666, 14255,
          2094,  1024, 28202,  4012,  2213,  1024,  2753,  2509, 25353, 15782,
          3363,  1024, 20666, 14255,  2094,  1024, 28202,  4012,  2213,  1024,
          2753,  2509, 25353, 15782,  3363,  1024, 20666, 14255,  2094,  1024,
         28202,  4012,  2213,  1024,  2753,  2509, 25353, 15782,  3363,  1024,
         20666, 14255,  2094,  1024, 28202,  4012,  2213,  1024,  2753,  2509,
         25353, 15782,  3363,  1024, 20666, 14255,  2094,  1024, 28202,  4012,
          2213,  1024,  2753,  2509, 25353, 15782,  3363,  1024, 20666, 14255,
          2094,  1024, 28202,  4012,  2213,  1024,  2753,  2509, 25353, 15782,
          3363,  1024, 20666, 14255,  2094,  1024, 28202,  4012,  2213,  1024,
          2753,  2509, 25353, 15782,  

In [27]:
import torch
print("MPS available:", torch.backends.mps.is_available())
print("MPS built:", torch.backends.mps.is_built())

MPS available: True
MPS built: True


In [28]:
!pip3 install --pre --force-reinstall torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

Looking in indexes: https://download.pytorch.org/whl/nightly/cpu
Collecting torch
  Using cached https://download.pytorch.org/whl/nightly/cpu/torch-2.6.0.dev20241209-cp312-none-macosx_11_0_arm64.whl (66.2 MB)
Collecting torchvision
  Using cached https://download.pytorch.org/whl/nightly/cpu/torchvision-0.20.0.dev20241209-cp312-cp312-macosx_11_0_arm64.whl (1.9 MB)
Collecting torchaudio
  Using cached https://download.pytorch.org/whl/nightly/cpu/torchaudio-2.5.0.dev20241209-cp312-cp312-macosx_11_0_arm64.whl (1.8 MB)
Collecting filelock (from torch)
  Using cached https://download.pytorch.org/whl/nightly/filelock-3.16.1-py3-none-any.whl (16 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Using cached https://download.pytorch.org/whl/nightly/typing_extensions-4.12.2-py3-none-any.whl (37 kB)
Collecting networkx (from torch)
  Using cached https://download.pytorch.org/whl/nightly/networkx-3.4.2-py3-none-any.whl (1.7 MB)
Collecting jinja2 (from torch)
  Using cached https://download.p

In [29]:
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
from torch.utils.data import Dataset, DataLoader
import torch
from tqdm import tqdm

class SyscallMLMDataset(Dataset):
    def __init__(self, sequences, tokenizer, max_length=512):
        self.sequences = sequences
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.sequences)
    
    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        
        # Tokenize sequence
        encoding = self.tokenizer(
            sequence,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        
        # Create MLM inputs (randomly mask tokens)
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)
        
        # Create MLM labels (copy of input_ids)
        labels = input_ids.clone()
        
        # Randomly mask tokens for MLM
        probability_matrix = torch.full(labels.shape, 0.15)
        special_tokens_mask = self.tokenizer.get_special_tokens_mask(labels, already_has_special_tokens=True)
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        
        # Set labels for unmasked tokens to -100 (ignored in loss computation)
        labels[~masked_indices] = -100
        
        # 80% of the time, replace masked tokens with [MASK]
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        
        # 10% of the time, replace masked tokens with random words
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        input_ids[indices_random] = random_words[indices_random]
        
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }

from itertools import cycle

def pretrain_mlm(train_sequences, val_sequences=None, num_epochs=3):
    # Initialize tokenizer and model
    tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
    model = DistilBertForMaskedLM.from_pretrained('distilbert-base-uncased')
    
    # Create datasets and dataloaders
    train_dataset = SyscallMLMDataset(train_sequences, tokenizer)
    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        pin_memory=True
    )
    
    val_loader = None
    if val_sequences:
        val_dataset = SyscallMLMDataset(val_sequences, tokenizer)
        val_loader = DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            pin_memory=True
        )
        val_iterator = cycle(val_loader)
    
    # Setup training
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    device = "cpu"
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    
    # Training loop
    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0
        total_train_accuracy = 0
        total_val_loss = 0
        total_val_accuracy = 0
        val_steps = 0
        
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        for batch_idx, train_batch in enumerate(progress_bar):
            # Training step
            optimizer.zero_grad()
            train_inputs = {k: v.to(device) for k, v in train_batch.items()}
            train_outputs = model(**train_inputs)
            train_loss = train_outputs.loss
            train_accuracy = (train_outputs.logits.argmax(dim=-1) == train_inputs['labels']).float().mean()
            
            train_loss.backward()
            optimizer.step()
            
            total_train_loss += train_loss.item()
            total_train_accuracy += train_accuracy.item()
            
            # Validation step (if validation data exists)
            if val_loader:
                model.eval()
                with torch.no_grad():
                    val_batch = next(val_iterator)
                    val_inputs = {k: v.to(device) for k, v in val_batch.items()}
                    val_outputs = model(**val_inputs)
                    val_loss = val_outputs.loss
                    val_accuracy = (val_outputs.logits.argmax(dim=-1) == val_inputs['labels']).float().mean()
                    
                    total_val_loss += val_loss.item()
                    total_val_accuracy += val_accuracy.item()
                    val_steps += 1
                model.train()
            
            # Update progress bar
            avg_train_loss = total_train_loss / (batch_idx + 1)
            avg_train_acc = total_train_accuracy / (batch_idx + 1)
            postfix = {
                'train_loss': f"{avg_train_loss:.4f}",
                'train_acc': f"{avg_train_acc:.4f}"
            }
            
            if val_loader:
                avg_val_loss = total_val_loss / val_steps
                avg_val_acc = total_val_accuracy / val_steps
                postfix.update({
                    'val_loss': f"{avg_val_loss:.4f}",
                    'val_acc': f"{avg_val_acc:.4f}"
                })
            
            progress_bar.set_postfix(postfix)
        
        # Print epoch summary
        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Training - Loss: {avg_train_loss:.4f}, Accuracy: {avg_train_acc:.4f}")
        if val_loader:
            print(f"Validation - Loss: {avg_val_loss:.4f}, Accuracy: {avg_val_acc:.4f}")
    
    return model, tokenizer

# Use your existing sequences
# Assuming train_sequences exists from your previous code
pretrained_model, tokenizer = pretrain_mlm(train_sequences, val_sequences, num_epochs=3)

# Save the pretrained model and tokenizer
pretrained_model.save_pretrained("syscall_pretrained_distilbert")
tokenizer.save_pretrained("syscall_pretrained_distilbert")

Epoch 1:   0%|          | 5/12629 [01:43<72:15:55, 20.61s/it, train_loss=0.7276, train_acc=0.1291, val_loss=0.4789, val_acc=0.1382]


KeyboardInterrupt: 

In [26]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
tokens = tokenizer.tokenize("100023")
print(tokens)

['1000', '##23']


In [27]:
from transformers import DistilBertConfig, DistilBertForMaskedLM
from torch.utils.data import DataLoader
import random

class SyscallPretrainingDataset(Dataset):
    def __init__(self, raw_windows, tokenizer, max_length=3072, mlm_probability=0.15, init=True):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.mlm_probability = mlm_probability
        self.preprocessed_input_ids = []
        self.preprocessed_attention_masks = []
        self.preprocessed_labels = []
        if init:
            # Process raw windows (no augmentation)
            for window in tqdm(raw_windows, desc="Preprocessing Windows"):
                sequence = " ".join([
                    f"PID:{' '.join(str(int(row[1])))} "
                    f"Comm:{' '.join(str(int(row[2])))} "
                    f"Syscall:{' '.join(str(int(row[3])))}"
                    for row in window
                ])
                # print(f"sequence: {sequence}")
                encoded = self.tokenizer.encode(
                    sequence,
                    max_length=self.max_length,
                    padding=True
                )
                # print(f"encoded: {encoded}")
                input_ids = encoded["input_ids"].squeeze(0)
                attention_mask = encoded["attention_mask"].squeeze(0)
                input_ids, labels = self._mask_tokens(input_ids)
                self.preprocessed_input_ids.append(input_ids)
                self.preprocessed_attention_masks.append(attention_mask)
                self.preprocessed_labels.append(labels)
    
    def __len__(self):
        return len(self.preprocessed_labels)
    
    def _mask_tokens(self, inputs):
        """

        Applies MLM masking to input tokens
        
        """
        labels = inputs.clone()
        input_length = len(inputs)
        
        special_tokens_mask = torch.tensor([
            1 if idx in [self.tokenizer.vocab['[PAD]'], 
                        self.tokenizer.vocab['[CLS]'], 
                        self.tokenizer.vocab['[SEP]']]
            else 0
            for idx in inputs
        ], dtype=torch.bool)
        
        # make it so theres 0 chance of masking special tokens
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        
        # creates a mask of the same shape as inputs
        # [True, False, True, False, True]
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # invert the mask to assign -100 to the unmasked indices
        labels[~masked_indices] = -100  # only compute loss on masked tokens
        
        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.vocab['[MASK]']
        
        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer.vocab), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        
        # 10% of the time, we keep the original word but include it in the loss so the model learns to ignore it.

        return inputs, labels
    
    def __getitem__(self, idx):
        input_ids = self.preprocessed_input_ids[idx]
        attention_mask = self.preprocessed_attention_masks[idx]
        labels = self.preprocessed_labels[idx]
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

In [28]:
class SyscallPretrainingDataset(Dataset):
    def __init__(self, raw_windows, tokenizer, max_length=128, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        print("Preprocessing windows...")
        # Track sequence lengths
        max_seq_length = 0
        self.sequences = [
            " ".join([
                f"PID:{' '.join(str(int(row[1])))} "
                f"Comm:{' '.join(str(int(row[2])))} "
                f"Syscall:{' '.join(str(int(row[3])))}"
                for row in window
            ])
            for window in tqdm(raw_windows)
        ]
        
        print("Encoding and masking sequences...")
        self.preprocessed_data = []
        progress_bar = tqdm(self.sequences, desc="Encoding and Masking Sequences")
        for seq in progress_bar:
            # Encode without padding first to get true length
            tokens = self.tokenizer.tokenize(seq)
            seq_length = len(tokens) + 2  # +2 for [CLS] and [SEP]
            max_seq_length = seq_length if seq_length > max_seq_length else max_seq_length
            progress_bar.set_postfix({'max_seq_length': max_seq_length})
            # Now encode with padding
            encoded = self.tokenizer.encode(
                seq,
                max_length=max_length,
                padding=True
            )
            
            input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long).squeeze(0)
            attention_mask = torch.tensor(encoded['attention_mask'], dtype=torch.long).squeeze(0)
            
            masked_input, labels = self._mask_tokens(
                input_ids, 
                mlm_probability
            )
            
            self.preprocessed_data.append({
                "input_ids": masked_input,
                "attention_mask": attention_mask,
                "labels": labels
            })
    
    @staticmethod
    def _mask_tokens(inputs, mlm_probability):
        """Static method for masking - only called once during preprocessing"""
        labels = inputs.clone()
        
        special_tokens_mask = torch.tensor([
            1 if idx in [0, 1, 2]
            else 0
            for idx in inputs
        ], dtype=torch.bool)
        
        probability_matrix = torch.full(labels.shape, mlm_probability)
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100
        
        # 80% [MASK], 10% random, 10% unchanged
        rand = torch.rand(labels.shape)
        mask_indices = (rand < 0.8) & masked_indices
        random_indices = (rand >= 0.8) & (rand < 0.9) & masked_indices
        
        inputs_masked = inputs.clone()
        inputs_masked[mask_indices] = 3  # [MASK] token
        random_words = torch.randint(4, 18, labels.shape)  # Random tokens (excluding special tokens)
        inputs_masked[random_indices] = random_words[random_indices]
        
        return inputs_masked, labels
    
    def __len__(self):
        return len(self.preprocessed_data)
    
    def __getitem__(self, idx):
        return self.preprocessed_data[idx]

def pretrain_model(pretrain_dataset, num_epochs=3):
    config = DistilBertConfig(
        vocab_size=18,
        max_position_embeddings=128,
        num_attention_heads=4,
        num_hidden_layers=4,
        hidden_size=128,
        intermediate_size=512
    )
    
    model = DistilBertForMaskedLM(config)
    train_loader = DataLoader(
        pretrain_dataset,
        batch_size=8,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    
    model.train()
    for epoch in range(num_epochs):
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")
        total_loss = 0
        
        for batch in progress_bar:
            optimizer.zero_grad()
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            progress_bar.set_postfix({'loss': loss.item()})
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1} average loss: {avg_loss}")
    
    return model

In [36]:
type(train_windows)

list

In [38]:
train_windows[1:]

[    Timestamp   PID  Comm  Syscall
 25  11156.355  1605   503    232.0
 26  11156.355  1605   503    232.0
 27  11156.355  1605   503    232.0
 28  11156.356  1605   503    232.0
 29  11156.356  1605   503    232.0
 30  11156.356  1605   503    232.0
 31  11156.356  1605   503    232.0
 32  11156.357  1605   503    232.0
 33  11156.357  1605   503    232.0
 34  11156.357  1605   503    232.0
 35  11156.358  1605   503    232.0
 36  11156.358  1605   503    232.0
 37  11156.358  1605   503    232.0
 38  11156.358  1605   503    232.0
 39  11156.359  1605   503    232.0
 40  11156.359  1605   503    232.0
 41  11156.359  1605   503    232.0
 42  11156.360  1605   503    232.0
 43  11156.360  1605   503    232.0
 44  11156.360  1605   503    232.0
 45  11156.361  1605   503    232.0
 46  11156.361  1605   503    232.0
 47  11156.361  1605   503    232.0
 48  11156.362  1605   503    232.0
 49  11156.362  1605   503    232.0
 50  11156.362  1605   503    232.0
 51  11156.363  1605   503  

In [35]:
train_windows_numpy = train_windows.values
val_windows_numpy = val_windows.values

AttributeError: 'list' object has no attribute 'values'

In [32]:
train_sequences = [
    " ".join([
        f"PID: {str(int(row[1]))} "
        f"Comm: {str(int(row[2]))} "
        f"Syscall: {str(int(row[3]))}"
        for row in window
    ])
    for window in tqdm(train_windows_numpy)
]
val_sequences = [
    " ".join([
        f"PID: {str(int(row[1]))} "
        f"Comm: {str(int(row[2]))} "
        f"Syscall: {str(int(row[3]))}"
        for row in window
    ])
    for window in tqdm(val_windows_numpy)
]

  0%|          | 0/404125 [00:00<?, ?it/s]


ValueError: invalid literal for int() with base 10: 'i'

In [29]:
import os
import pickle
pretrain_dataset = None
if not os.path.exists('pretrain_dataset_tokenized.pkl'):
    pretrain_dataset = SyscallPretrainingDataset(train_windows, tokenizer, max_length=448)
    with open('pretrain_dataset_tokenized.pkl', 'wb') as f:
        pickle.dump(pretrain_dataset, f)
else:
    with open('pretrain_dataset_tokenized.pkl', 'rb') as f:
        pretrain_dataset = pickle.load(f)  # Loads back as SyscallPretrainingDataset with all data

In [31]:
sentences = [sentence for sentence in tqdm(pretrain_dataset.sequences, desc="Extracting sentences")]

AttributeError: 'SyscallPretrainingDataset' object has no attribute 'sequences'

In [32]:
from datetime import datetime
import time
import torch
import torch.multiprocessing as mp
mp.set_start_method('spawn', force=True)  # Add this at the start of your script

def pretrain_model(dataset):
    config = DistilBertConfig(
        vocab_size=18, 
        max_position_embeddings=448,  # Matches your longest sequence
        num_attention_heads=4,
        num_hidden_layers=4,
        hidden_size=128,  # Reduced from 256 to save memory
        intermediate_size=64
    )

    model = DistilBertForMaskedLM(config)
    print(f"dataloader created at {datetime.now()}")
    # Smaller batch size to handle longer sequences
    pretrain_loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        num_workers=0,  # Add parallel data loading
        pin_memory=True  # Faster data transfer to device
    )
    print(f"dataloader finished at {datetime.now()}")
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.5)
    device = torch.device(torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu"))
    # Enable async data transfer
    device = torch.device("cpu")
    model.to(device)
    
    print(f"Training started on {device} at {datetime.now()}")
    print(f"Number of parameters: {sum(p.numel() for p in model.parameters())}")
    model.train()
    try:
        for epoch in range(10):
            total_loss = 0
            total_accuracy = 0
            progress_bar = tqdm(pretrain_loader, desc=f"Epoch {epoch+1}")
            for idx, batch in enumerate(progress_bar):
                if torch.backends.mps.is_available():
                    torch.mps.empty_cache()
                optimizer.zero_grad()
                inputs = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**inputs)
                loss = outputs.loss
                predictions = outputs.logits.argmax(dim=-1)  # [batch, seq_len]
                labels = inputs["labels"]                    # [batch, seq_len]
                
                # Only consider positions where labels != -100
                valid_positions = (labels != -100)
                if valid_positions.any():
                    accuracy = (predictions[valid_positions] == labels[valid_positions]).float().mean()
                else:
                    accuracy = torch.tensor(0.0)
                loss.backward()
                optimizer.step()
                
                total_loss += loss.item()
                total_accuracy += accuracy.item()
                progress_bar.set_postfix({
                    'loss': f"{total_loss/(idx+1):.4f}", 
                    'acc': f"{total_accuracy/(idx+1):.4f}"
                })
                
            avg_loss = total_loss / len(pretrain_loader)
            avg_accuracy = total_accuracy / len(pretrain_loader)
            print(f"Epoch {epoch+1} average loss: {avg_loss:.4f}")
            print(f"Epoch {epoch+1} average accuracy: {avg_accuracy:.4f}")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user")
        return model

model = pretrain_model(pretrain_dataset)

dataloader created at 2024-12-08 02:52:11.561297
dataloader finished at 2024-12-08 02:52:11.561548
Training started on cpu at 2024-12-08 02:52:11.562600
Number of parameters: 3501458


Epoch 1: 100%|██████████| 14033/14033 [6:04:39<00:00,  1.56s/it, loss=0.3332, acc=0.8840]  


Epoch 1 average loss: 0.3332
Epoch 1 average accuracy: 0.8840


Epoch 2:  13%|█▎        | 1766/14033 [42:56<4:58:14,  1.46s/it, loss=0.0576, acc=0.9815]


Training interrupted by user





In [33]:
model.save_pretrained("pretrained_model_distilbert_mlm")

AttributeError: 'SyscallTokenizer' object has no attribute 'save_pretrained'

In [111]:
class SyscallCriticDataset(Dataset):
    def __init__(self, raw_windows, labels, types):
        self.labels = labels
        self.types = types
        self.sequences = [
            " ".join([
                f"PID:{' '.join(str(int(row[1])))} "
                f"Comm:{' '.join(str(int(row[2])))} "
                f"Syscall:{' '.join(str(int(row[3])))}"
                for row in window
            ])
            for window in tqdm(raw_windows, desc="Preprocessing windows")
        ]
    def __len__(self):
        return len(self.sequences)
    def __getitem__(self, idx):
        return {
            "sequence": self.sequences[idx],
            "label": self.labels[idx],
            "type": self.types[idx]
        }
        # print("Encoding and masking sequences...")
        # self.preprocessed_data = []
        # progress_bar = tqdm(self.sequences, desc="Encoding and Masking Sequences")
        # for seq in progress_bar:
        #     # Encode without padding first to get true length
        #     tokens = self.tokenizer.tokenize(seq)

        #     # test for max length
        #     seq_length = len(tokens) + 2  # +2 for [CLS] and [SEP]
        #     max_seq_length = seq_length if seq_length > max_seq_length else max_seq_length
        #     progress_bar.set_postfix({'max_seq_length': max_seq_length})

        #     encoded = self.tokenizer.encode(
        #         seq,
        #         max_length=max_length,
        #         padding=True
        #     )
            
        #     input_ids = torch.tensor(encoded['input_ids'], dtype=torch.long).squeeze(0)
        #     attention_mask = torch.tensor(encoded['attention_mask'], dtype=torch.long).squeeze(0)
            
        #     masked_input, labels = self._mask_tokens(
        #         input_ids, 
        #         mlm_probability
        #     )
            
        #     self.preprocessed_data.append({
        #         "input_ids": masked_input,
        #         "attention_mask": attention_mask,
        #         "labels": labels
        #     })
    

In [112]:
sequence_dataset = SyscallCriticDataset(train_windows_augmented, train_labels_augmented, train_augmentation_types)

Preprocessing windows: 100%|██████████| 404125/404125 [00:22<00:00, 18104.64it/s]


In [139]:
class SyscallLMCritic:
    def __init__(self, pretrained_model, tokenizer, device='cpu'):
        self.model = pretrained_model.to(device)
        self.tokenizer = tokenizer
        self.device = device
        self.model.eval()
    
    def get_sequence_probability(self, sequence):
        """Get probability score from pretrained MLM"""
        with torch.no_grad():
            inputs = self.tokenizer.encode(sequence, max_length=448, padding=True)
            # print(inputs)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}
            inputs['labels'] = inputs['input_ids']
            outputs = self.model(**inputs)
            # print(outputs)
            return -outputs.loss.item()  # log loss
        
    def parse_sequence(self, sequence):
        """Parse sequence into structured format using pattern matching"""
        parts = []
        current_label = None
        current_numbers = []
        
        tokens = sequence.split()
        for token in tokens:
            if token.endswith(':'):  # Found a label (PID:/Comm:/Syscall:)
                if current_label is not None:
                    parts.append((current_label, current_numbers))
                current_label = token
                current_numbers = []
            else:  # Found a number
                current_numbers.append(token)
        
        # Add the last group
        if current_label is not None:
            parts.append((current_label, current_numbers))
            
        return parts
    
    def generate_perturbations(self, sequence, n_samples=100):
        """Generate perturbations by only modifying numbers within valid ranges"""
        perturbations = []
        perturbations.append(sequence) # original sequence at index 0
        parts = self.parse_sequence(sequence)
        
        for _ in range(n_samples):
            perturbed_parts = []
            
            for label, numbers in parts:
                if label == 'PID:':
                    new_num = random.randint(0, 5000)
                    new_numbers = [d for d in str(new_num)]
                elif label == 'Comm:':
                    new_num = random.randint(0, 9999)
                    new_numbers = [d for d in str(new_num)]
                elif label == 'Syscall:':
                    new_num = random.randint(0, 360)
                    new_numbers = [d for d in str(new_num)]
                perturbed_parts.append((label, new_numbers))
            
            perturbed_sequence = ' '.join([
                label + ' ' + ' '.join(numbers)
                for label, numbers in perturbed_parts
            ])
            perturbations.append(perturbed_sequence)
        
        return perturbations
    
    def is_grammatical(self, sequence, verbose=False):
        """Implement local optimum criterion"""
        # print(f"sequence: {sequence}")
        perturbations = self.generate_perturbations(sequence)
        perturbations.append(sequence)
        # print(f"perturbations: {perturbations}")
        
        # Get probabilities
        probs = [self.get_sequence_probability(seq) for seq in perturbations]
        original_idx = 0
        best_idx = probs.index(max(probs))
        # Check if original sequence has highest probability
        if verbose:
            if probs[original_idx] == max(probs):
                print(f'Good! Your sequence log(p) = {float(probs[original_idx]):.3f}')
            else:
                print(f'Bad! Your sequence log(p) = {float(probs[original_idx]):.3f}')
                print(f'Neighbor sequence with highest log(p): {perturbations[best_idx]} (= {float(probs[best_idx]):.3f})')
        
        return probs[original_idx] == max(probs)  # min because we're using loss

In [140]:
critic = SyscallLMCritic(model, tokenizer)
critic.is_grammatical("PID: 1 2 3 4 Comm: 1 2 3 4 Syscall: 1 2 3")

False

In [141]:
sequence = "PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2"

def parse_sequence(sequence):
    """Parse sequence into structured format using pattern matching"""
    parts = []
    current_label = None
    current_numbers = []
    
    tokens = sequence.split()
    for token in tokens:
        if token.endswith(':'):  # label (PID:/Comm:/Syscall:)
            if current_label is not None:
                parts.append((current_label, current_numbers))
            current_label = token
            current_numbers = []
        else:
            current_numbers.append(token)
    
    if current_label is not None:
        parts.append((current_label, current_numbers))
        
    return parts

parsed = parse_sequence(sequence)
parsed

[('PID:', ['1', '6', '0', '5']),
 ('Comm:', ['5', '0', '3']),
 ('Syscall:', ['2', '3', '2']),
 ('PID:', ['1', '6', '0', '5']),
 ('Comm:', ['5', '0', '3']),
 ('Syscall:', ['2', '3', '2']),
 ('PID:', ['1', '6', '0', '5']),
 ('Comm:', ['5', '0', '3']),
 ('Syscall:', ['2', '3', '2']),
 ('PID:', ['1', '6', '0', '5']),
 ('Comm:', ['5', '0', '3']),
 ('Syscall:', ['2', '3', '2']),
 ('PID:', ['1', '6', '0', '5']),
 ('Comm:', ['5', '0', '3']),
 ('Syscall:', ['2', '3', '2'])]

In [142]:
def generate_perturbations(sequence, n_samples=100):
    """Generate perturbations by only modifying numbers within valid ranges"""
    perturbations = []
    perturbations.append(sequence)
    parts = parse_sequence(sequence)
    print(f"parts: {parts}")
    for _ in range(n_samples):
        perturbed_parts = []
        
        for label, numbers in parts:
            if label == 'PID:':
                new_num = random.randint(0, 5000)
                new_numbers = [d for d in str(new_num)]
            elif label == 'Comm:':
                new_num = random.randint(0, 9999)
                new_numbers = [d for d in str(new_num)]
            elif label == 'Syscall:':
                new_num = random.randint(0, 360)
                new_numbers = [d for d in str(new_num)]
            perturbed_parts.append((label, new_numbers))
        
        # Reconstruct sequence
        perturbed_sequence = ' '.join([
            label + ' ' + ' '.join(numbers)
            for label, numbers in perturbed_parts
        ])
        perturbations.append(perturbed_sequence)

    return perturbations
perturbations = generate_perturbations(sequence)
perturbations

parts: [('PID:', ['1', '6', '0', '5']), ('Comm:', ['5', '0', '3']), ('Syscall:', ['2', '3', '2']), ('PID:', ['1', '6', '0', '5']), ('Comm:', ['5', '0', '3']), ('Syscall:', ['2', '3', '2']), ('PID:', ['1', '6', '0', '5']), ('Comm:', ['5', '0', '3']), ('Syscall:', ['2', '3', '2']), ('PID:', ['1', '6', '0', '5']), ('Comm:', ['5', '0', '3']), ('Syscall:', ['2', '3', '2']), ('PID:', ['1', '6', '0', '5']), ('Comm:', ['5', '0', '3']), ('Syscall:', ['2', '3', '2'])]


['PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2 PID: 1 6 0 5 Comm: 5 0 3 Syscall: 2 3 2',
 'PID: 1 1 Comm: 3 6 2 5 Syscall: 2 4 5 PID: 1 3 4 0 Comm: 2 5 8 2 Syscall: 2 3 4 PID: 1 1 8 4 Comm: 5 9 7 Syscall: 9 5 PID: 2 4 2 0 Comm: 5 8 3 2 Syscall: 1 8 2 PID: 3 4 6 4 Comm: 5 8 7 5 Syscall: 2 5 4',
 'PID: 1 7 2 9 Comm: 9 4 5 4 Syscall: 3 4 5 PID: 8 4 6 Comm: 3 5 3 8 Syscall: 1 1 6 PID: 4 2 7 Comm: 8 0 1 5 Syscall: 1 5 2 PID: 4 6 3 6 Comm: 5 2 2 1 Syscall: 2 9 7 PID: 2 3 7 3 Comm: 9 7 9 3 Syscall: 1 9 6',
 'PID: 2 9 5 8 Comm: 3 5 6 1 Syscall: 2 6 1 PID: 1 7 6 3 Comm: 3 7 1 7 Syscall: 1 2 9 PID: 4 1 5 7 Comm: 9 0 0 8 Syscall: 8 6 PID: 2 6 3 9 Comm: 4 3 3 5 Syscall: 2 6 3 PID: 2 6 0 1 Comm: 8 1 3 8 Syscall: 1 8 6',
 'PID: 4 2 9 6 Comm: 3 Syscall: 2 3 8 PID: 3 6 3 4 Comm: 7 4 7 8 Syscall: 1 6 0 PID: 3 0 5 7 Comm: 7 9 5 7 Syscall: 3 8 PID: 3 2 5 5 Comm: 8 7 4 2 Syscall: 1 7 4 PID: 3

In [143]:
def evaluate_critic(critic, dataset, limit=2000):
    """
    Evaluate LM-Critic performance using SyscallCriticDataset
    dataset: instance of SyscallCriticDataset
    """
    correct = 0
    total = 0
    results = {
        'normal': {'correct': 0, 'total': 0},
        'blatant': {'correct': 0, 'total': 0},
        'apt': {'correct': 0, 'total': 0}
    }
    
    for idx in tqdm(range(limit), desc=f"Evaluating sequences on first {limit}"):
        sample = dataset[idx]
        sequence = sample["sequence"]
        label = sample["label"]
        type_idx = sample["type"]
        
        # Get critic prediction (True = grammatical/normal, False = ungrammatical/anomalous)
        is_normal = critic.is_grammatical(sequence)
        
        # Get true label (normal is [0,0,1])
        true_normal = (label == [0,0,1]).all()
        # print(f"label: {label}")
        # print(f"true_normal: {true_normal}")
        # print(f"is_normal: {is_normal}")
        # Update counts
        if is_normal == true_normal:
            correct += 1
            if true_normal:
                results['normal']['correct'] += 1
            elif type_idx == 0:  # blatant
                results['blatant']['correct'] += 1
            elif type_idx == 1:  # apt
                results['apt']['correct'] += 1
        
        # Update totals
        total += 1
        if true_normal:
            results['normal']['total'] += 1
        elif type_idx == 0:
            results['blatant']['total'] += 1
        elif type_idx == 1:
            results['apt']['total'] += 1
        
        if total % 100 == 0:
            print(f"\nCurrent accuracy: {correct/total:.4f}")
            for category in results:
                if results[category]['total'] > 0:
                    cat_acc = results[category]['correct'] / results[category]['total']
                    print(f"{category} accuracy: {cat_acc:.4f}")
    
    return results

critic = SyscallLMCritic(model, tokenizer)
results = evaluate_critic(critic, sequence_dataset)

Evaluating sequences on first 2000:   5%|▌         | 100/2000 [02:31<47:40,  1.51s/it]


Current accuracy: 0.5800
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  10%|█         | 200/2000 [05:02<45:46,  1.53s/it]


Current accuracy: 0.5150
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  15%|█▌        | 300/2000 [07:42<47:34,  1.68s/it]  


Current accuracy: 0.5300
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  20%|██        | 400/2000 [10:19<40:20,  1.51s/it]


Current accuracy: 0.5125
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  25%|██▌       | 500/2000 [13:10<37:18,  1.49s/it]


Current accuracy: 0.4860
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  30%|███       | 600/2000 [15:54<38:31,  1.65s/it]


Current accuracy: 0.4783
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  35%|███▌      | 700/2000 [18:40<35:27,  1.64s/it]


Current accuracy: 0.4814
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  40%|████      | 800/2000 [21:23<32:25,  1.62s/it]


Current accuracy: 0.4738
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  45%|████▌     | 900/2000 [24:20<28:18,  1.54s/it]


Current accuracy: 0.4844
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  50%|█████     | 1000/2000 [27:13<26:25,  1.59s/it]


Current accuracy: 0.4900
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  55%|█████▌    | 1100/2000 [29:42<22:20,  1.49s/it]


Current accuracy: 0.4809
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  60%|██████    | 1200/2000 [32:12<20:10,  1.51s/it]


Current accuracy: 0.4792
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  65%|██████▌   | 1300/2000 [34:42<17:24,  1.49s/it]


Current accuracy: 0.4800
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  70%|███████   | 1400/2000 [37:14<15:00,  1.50s/it]


Current accuracy: 0.4779
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  75%|███████▌  | 1500/2000 [39:44<12:27,  1.49s/it]


Current accuracy: 0.4760
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  80%|████████  | 1600/2000 [42:15<10:28,  1.57s/it]


Current accuracy: 0.4813
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  85%|████████▌ | 1700/2000 [44:58<08:07,  1.62s/it]


Current accuracy: 0.4806
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  90%|█████████ | 1800/2000 [47:52<05:29,  1.65s/it]


Current accuracy: 0.4778
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000:  95%|█████████▌| 1900/2000 [50:35<02:35,  1.56s/it]


Current accuracy: 0.4847
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000


Evaluating sequences on first 2000: 100%|██████████| 2000/2000 [53:09<00:00,  1.59s/it]


Current accuracy: 0.4885
normal accuracy: 1.0000
blatant accuracy: 0.0000
apt accuracy: 0.0000



