In [None]:
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score
import matplotlib.pyplot as plt

In [2]:
# Fully connected classifier: 42 input features -> three hidden layers of 64
# units (ReLU after each) -> 24 output logits. The widths are listed once and
# the layers are generated in the same order as a hand-written definition.
_LAYER_WIDTHS = (42, 64, 64, 64, 24)

_stacked_layers = [
    layer
    for n_in, n_out in zip(_LAYER_WIDTHS, _LAYER_WIDTHS[1:])
    for layer in (nn.Linear(n_in, n_out), nn.ReLU())
]

# The last Linear produces raw logits, so the ReLU generated after it is dropped.
model = nn.Sequential(*_stacked_layers[:-1])

In [3]:
def prep_df(df, binary=False, normal_label="normal.") -> pd.DataFrame:
    """Return a copy of ``df`` with its categorical columns integer-encoded.

    The columns ``protocol_type``, ``service``, ``flag`` and ``attack`` are
    converted to pandas categoricals and replaced by their category codes.
    The input frame is not modified.

    Parameters
    ----------
    df : pd.DataFrame
        Frame containing at least the four categorical columns above, with
        ``attack`` holding the raw (string) labels.
    binary : bool, default False
        When True, ``attack`` is collapsed to 0 for rows whose raw label equals
        ``normal_label`` and 1 for every other row.
    normal_label : str, default "normal."
        Raw label that marks benign traffic; only used when ``binary`` is True.

    Returns
    -------
    pd.DataFrame
        Encoded copy of ``df``.
    """
    categorical_cols = ["protocol_type", "service", "flag", "attack"]
    ret = df.copy()

    # Decide normal-vs-attack on the raw labels *before* encoding. Category
    # codes depend on which labels happen to be present in the data, so
    # comparing against a fixed integer code silently picks the wrong class
    # whenever the label set changes (e.g. when a new attack type is added).
    is_normal = (ret["attack"] == normal_label) if binary else None

    for col in categorical_cols:
        ret[col] = ret[col].astype("category").cat.codes

    if binary:
        ret["attack"] = np.where(is_normal, 0, 1)
    return ret

In [6]:
# Read the base KDD sample and the cryptomining sample, stack them row-wise
# (original row labels are kept, so index values repeat across the two parts)
# and integer-encode the categorical columns.
kdd_raw = pd.read_csv("./data/KDD_small.csv")
cryptomining_raw = pd.read_csv("./data/cryptomining_kdd.csv")
combined_raw = pd.concat([kdd_raw, cryptomining_raw])

df = prep_df(combined_raw)
df

Unnamed: 0.1,Unnamed: 0,duration,protocol_type,service,flag,src_bytes,dst_bytes,land,wrong_fragment,urgent,...,dst_host_srv_count,dst_host_same_srv_rate,dst_host_diff_srv_rate,dst_host_same_src_port_rate,dst_host_srv_diff_host_rate,dst_host_serror_rate,dst_host_srv_serror_rate,dst_host_rerror_rate,dst_host_srv_rerror_rate,attack
0,0.0,0,2,46,9,146,0,0,0,0,...,1,0.00,0.60,0.88,0.00,0.00,0.00,0.0,0.00,12
1,1.0,0,1,51,5,0,0,0,0,0,...,26,0.10,0.05,0.00,0.00,1.00,1.00,0.0,0.00,10
2,2.0,0,1,25,9,232,8153,0,0,0,...,255,1.00,0.00,0.03,0.04,0.03,0.01,0.0,0.01,12
3,3.0,0,1,25,9,199,420,0,0,0,...,255,1.00,0.00,0.00,0.00,0.00,0.00,0.0,0.00,12
4,4.0,0,1,51,1,0,0,0,0,0,...,19,0.07,0.07,0.00,0.00,0.00,0.00,1.0,1.00,10
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
596,,0,2,13,5,0,0,0,0,0,...,100,0.00,0.00,0.00,0.00,1.00,0.00,0.0,0.00,2
597,,0,2,13,5,0,0,0,0,0,...,100,0.00,0.00,0.00,0.00,1.00,0.00,0.0,0.00,2
598,,0,2,13,5,0,0,0,0,0,...,100,0.00,0.00,0.00,0.00,1.00,0.00,0.0,0.00,2
599,,0,2,13,5,0,0,0,0,0,...,100,0.00,0.00,0.00,0.00,1.00,0.00,0.0,0.00,2


In [7]:
# Human-readable names for the integer `attack` codes.
# The position of a name in this list is used as its class id elsewhere in the
# notebook (`attack_types[i]` / `attack_types.index(name)`), so the order must
# match the category codes produced by `prep_df`. The list is kept in
# lexicographic order for that reason.
# NOTE(review): the index only equals the encoded code if every label listed
# here actually occurs in the loaded data — confirm for the datasets in use.
attack_types = [
    'back.',            # Back: denial-of-service attack against a web server using requests whose URL contains many slashes. NOTE(review): in the KDD Cup 99 label set 'back' is a DoS attack, not a backdoor — confirm against the dataset documentation.
    'buffer_overflow.', # Buffer overflow: attacker crafts input that overruns memory buffers to overwrite program state and execute arbitrary code (often leads to remote code execution). Signs: crashes, weird process behavior, or exploit attempts in application logs.
    'cryptomining',     # Cryptomining: the additional class under test; presumably the rows loaded from ./data/cryptomining_kdd.csv. Note the label has no trailing '.' unlike the other labels.
    'ftp_write.',       # FTP write: attacker authenticates or exploits an FTP server to upload or modify files (webshells, malware, or data tampering). Detection: unexpected new/modified files on FTP dirs or file uploads from unusual IPs.
    'guess_passwd.',    # Guessing passwords (brute force): automated attempts to discover credentials by trying many username/password combos. Detection: lots of failed login attempts, many distinct sources, or rapid repeated attempts from one IP.
    'imap.',            # IMAP attacks: attempts targeting IMAP mail servers (credential stuffing, brute force, or exploiting server bugs) to read/steal mail or use mail service as pivot. Watch for abnormal IMAP logins, unusual IPs, or excessive mailbox access.
    'ipsweep.',         # IP sweep: scanner probes a range of IP addresses to find hosts that respond (basic reconnaissance). Detection: many connection attempts across consecutive IPs and short time windows.
    'land.',            # LAND DoS: forged packets with the same source and destination IP/port (src==dst) that confuse some TCP/IP stacks and may crash or hang the target. Detection: packets where source==destination or peculiar TCP resets/crashes.
    'loadmodule.',      # Load module attack: attacker attempts to upload and load a malicious kernel/module or server module (executes code in privileged context). Look for suspicious module loads, new binaries, or privilege-escalation attempts.
    'multihop.',        # Multihop (proxying/pivoting): attacker routes access via one or more compromised machines to hide origin and reach otherwise inaccessible systems. Indicators: strange relay connections, tunnels, or unusual intermediate hosts in logs.
    'neptune.',         # Neptune (SYN flood): classic TCP SYN flood that sends many connection requests without completing handshakes, exhausting connection tables and causing DoS. Symptoms: many half-open connections and resource exhaustion on the target.
    'nmap.',            # Nmap scanning: active port/service/OS scanning with the nmap tool (reconnaissance to discover services and vulnerabilities). Detection: diverse ports probed from same IP, fingerprinting patterns or TTL/packet patterns matching nmap.
    'normal.',          # Normal traffic: legitimate, benign network activity — not an attack. Helpful as baseline for anomaly detection.
    'perl.',            # Perl/CGI script attacks: exploitation of Perl-based CGI scripts or server-side Perl apps to run arbitrary commands or upload malware. Watch for suspicious HTTP requests invoking CGI scripts or file writes from webserver processes.
    'phf.',             # phf CGI exploit: a historical web CGI vulnerability (phf) where specially crafted requests could execute commands on the server. Detection: web requests with unusual query payloads to phf or unexpected command output/files.
    'pod.',             # POD (Ping of Death / Packet of Death): sending oversized or malformed ICMP packets that cause older systems to crash or reboot. Symptoms: malformed ICMP traffic and sudden crashes/reboots after such packets.
    'portsweep.',       # Port sweep: scanning many ports across one or more hosts to find available services (more focused than IPSweep). Detect by seeing many connection attempts to different ports from the same source.
    'rootkit.',         # Rootkit: stealthy malware installed at high (often kernel) privilege to hide attacker presence and maintain persistent control. Look for hidden processes, altered system binaries, unusual kernel modules, or tampered audit logs.
    'satan.',           # SATAN scanner activity: use of the SATAN vulnerability scanner to enumerate known weaknesses (historical tool similar to Nessus). Detection: a set of targeted probes for known vulnerabilities and misconfigurations.
    'smurf.',           # Smurf attack: amplification DoS using spoofed ICMP echo requests sent to broadcast addresses which then flood the victim with replies. Detection: many ICMP replies from broadcast networks with spoofed source IPs.
    'spy.',             # Spy / sniffing activity: passive capture of network traffic or installation of spyware to steal credentials/data. Indicators: suspicious promiscuous-mode network interfaces, odd packet captures, or data exfiltration.
    'teardrop.',        # Teardrop: DoS that sends overlapping or malformed IP fragments which crash or confuse vulnerable IP reassembly code. Symptoms: fragmented packets with invalid offsets and crashes/restarts on affected systems.
    'warezclient.',     # Warez client behavior: client machines contacting servers to download pirated software — in IDS datasets this often signals suspicious P2P or illicit file-sharing activity which can correlate with other threats. Look for repeated connections to known warez servers or P2P ports.
    'warezmaster.'      # Warez master: server/operator coordinating the distribution of pirated software (the source/origin of warez). Detection: servers receiving many upload requests or hosting many illicit files and responding to many download clients.
]


In [None]:
def create_model(df=None, train_test=None, random_state=47, show_scores=False, binary=False):
    """Fit the notebook-level ``model`` and summarise its mistakes per true class.

    Exactly one data source is needed; if both are given, ``train_test`` wins.

    Parameters
    ----------
    df : pd.DataFrame, optional
        Encoded frame with an ``attack`` target column. It is split into
        train/test sets with ``train_test_split`` using ``random_state``.
    train_test : sequence, optional
        Pre-built ``(X_train, X_test, y_train, y_test)``.
    random_state : int, default 47
        Seed for the train/test split; only used when splitting ``df``.
    show_scores : bool, default False
        Print accuracy and precision on the test set.
    binary : bool, default False
        Use binary precision instead of micro-averaged precision when printing.

    Returns
    -------
    tuple
        ``(y_test, y_pred, ret_fails)`` where ``ret_fails`` maps an attack name
        from ``attack_types`` to the number of misclassified test samples whose
        true class is that attack (classes with zero failures are omitted).

    Raises
    ------
    ValueError
        If neither ``df`` nor ``train_test`` is supplied.
    TypeError
        If the global ``model`` does not provide ``fit`` and ``predict``.
    """
    # No `global`/`nonlocal` declaration: the module-level `model` is only read
    # here, and `nonlocal` is a SyntaxError in a function that is not nested.
    if train_test is not None:
        X_train, X_test, y_train, y_test = train_test
    elif isinstance(df, pd.DataFrame):
        X, y = df.drop(columns=["attack"]), df['attack']
        # Split data into training and testing sets
        X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=random_state)
    else:
        raise ValueError("create_model needs either a DataFrame `df` or a `train_test` tuple")

    # NOTE(review): the `model` defined earlier in this notebook is an
    # `nn.Sequential`, which has no `fit`/`predict`. This function needs an
    # estimator-style object (or a wrapper around the network); fail with a
    # clear message instead of an AttributeError deep in the call.
    if not (hasattr(model, "fit") and hasattr(model, "predict")):
        raise TypeError(
            f"model of type {type(model).__name__} must provide fit() and predict()"
        )

    # Train the model
    model.fit(X_train, y_train)

    # Make predictions
    y_pred = model.predict(X_test)

    if show_scores:
        accuracy = accuracy_score(y_test, y_pred)
        print(f"\t\tAccuracy: {accuracy}")

        precision = precision_score(y_test, y_pred) if binary else precision_score(y_test, y_pred, average='micro')
        print(f"\t\tPrecision: {precision}")

    # Compare positionally on plain arrays so pandas index labels (which may
    # repeat after concatenation) play no role.
    y_true_arr = np.asarray(y_test)
    misclassified = y_true_arr != np.asarray(y_pred)

    # Count failures per true class index
    fail_counts = np.bincount(y_true_arr[misclassified], minlength=len(attack_types))

    # NOTE(review): names are looked up in `attack_types` by class index, which
    # is only meaningful for multi-class labels; with binary 0/1 labels the
    # keys would be the first two attack names — confirm intended use.
    ret_fails = {
        attack_types[idx]: int(count)
        for idx, count in enumerate(fail_counts)
        if count > 0
    }

    return (y_test, y_pred, ret_fails)

In [None]:
def test_learn_time(df, df_crypt, additive_range=range(1, 101, 10), split_frac=.8, new_attack="normal.", debug=False, stop_at_data=False, num_est=None):
    """Measure how model quality changes as more samples of a new class are added.

    For every ``d`` in ``additive_range`` the base data ``df`` is split into
    train/test parts, ``d`` rows of ``df_crypt`` are added to the training part
    and the remaining ``df_crypt`` rows to the test part. The model is then
    trained and scored four times via ``create_model`` and the scores averaged.

    Parameters
    ----------
    df : pd.DataFrame
        Encoded data without the class under study.
    df_crypt : pd.DataFrame
        Encoded rows of the class under study.
    additive_range : iterable of int
        Numbers of ``df_crypt`` rows to add to the training set.
    split_frac : float, default 0.8
        Fraction of ``df`` used for training.
    new_attack : str, default "normal."
        Name (from ``attack_types``) of the class under study.
    debug : bool, default False
        Print progress and per-model scores.
    stop_at_data : bool, default False
        Return ``(X_train, X_test, y_train, y_test)`` of the first iteration
        instead of training anything.
    num_est : optional
        Accepted so existing callers that pass it keep working. It is not
        forwarded because ``create_model`` has no corresponding parameter.

    Returns
    -------
    tuple of lists
        ``(crypt_acc, model_acc, model_pre, model_f1)``, one entry per ``d``:
        mean recall on the new class, mean accuracy, mean weighted precision
        and mean weighted F1. A recall entry is NaN when no run had any test
        sample of the new class.
    """
    # Class index of the attack under study; constant for the whole run.
    new_attack_num = attack_types.index(new_attack)

    crypt_acc = []
    model_acc = []
    model_pre = []
    model_f1 = []

    for d in additive_range:
        crypt_acc_s = []
        model_acc_s = []
        model_pre_s = []
        model_f1_s = []

        if debug:
            print(f"   {d=}")

        # Base split of the data that does not contain the new class
        df_train, df_test = train_test_split(df, train_size=float(split_frac))

        # d samples of the new class go to training, the rest to testing
        df_crypt_train, df_crypt_test = train_test_split(
            df_crypt, train_size=d)

        train_data = pd.concat([df_train, df_crypt_train])
        X_train = train_data.drop(columns=['attack'])
        y_train = train_data['attack']

        test_data = pd.concat([df_test, df_crypt_test])
        X_test = test_data.drop(columns=['attack'])
        y_test_true = test_data['attack']

        if stop_at_data:
            return X_train, X_test, y_train, y_test_true

        # Binary ground truth (new_attack vs everything else). Built on a plain
        # array so that later checks look at values, never at index labels.
        y_test_binary = (np.asarray(y_test_true) == new_attack_num).astype(int)

        for seed in range(4):

            if debug:
                print(f"\tStarting Model {seed}")

            # Do NOT overwrite y_test_true
            _, y_pred, _ = create_model(
                train_test=[X_train, X_test, y_train, y_test_true],
                show_scores=debug,
                random_state=seed
            )

            y_pred_binary = (np.asarray(y_pred) == new_attack_num).astype(int)

            # Recall is only defined when the test set actually contains the
            # new class; otherwise skip it rather than record a misleading 0.
            if y_test_binary.any():
                new_attack_recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)

                if debug:
                    print(f"\t{new_attack_recall=}")

                crypt_acc_s.append(new_attack_recall)

            model_acc_s.append(accuracy_score(y_test_true, y_pred))
            model_pre_s.append(precision_score(
                y_test_true, y_pred, average='weighted', zero_division=0))
            # NOTE(review): zero_division differs from the precision call above
            # (1 vs 0); kept as-is — confirm this is intentional.
            model_f1_s.append(f1_score(y_test_true, y_pred,
                              average='weighted', zero_division=1))

        # np.mean([]) warns and yields NaN implicitly; make that case explicit.
        crypt_acc.append(np.mean(crypt_acc_s) if crypt_acc_s else float("nan"))
        model_acc.append(np.mean(model_acc_s))
        model_pre.append(np.mean(model_pre_s))
        model_f1.append(np.mean(model_f1_s))

    return crypt_acc, model_acc, model_pre, model_f1

In [None]:
def total_testing(df_total, new_range=range(1, 20), debug=False, min_samples=500):
    """Run ``test_learn_time`` once per attack class that has enough samples.

    Each class in ``attack_types`` is treated in turn as the "new" attack: its
    rows are separated from ``df_total`` and passed as ``df_crypt``, while all
    other rows form the base data.

    Parameters
    ----------
    df_total : pd.DataFrame
        Encoded data whose ``attack`` column holds class indices into
        ``attack_types``.
    new_range : iterable of int, default range(1, 20)
        Forwarded to ``test_learn_time`` as ``additive_range``.
    debug : bool, default False
        Forwarded to ``test_learn_time``.
    min_samples : int, default 500
        Classes with fewer rows than this are skipped.

    Returns
    -------
    list of tuple
        ``(attack_name, test_learn_time_result)`` for every class tested.
    """
    results = []
    # Iterate over the label list itself instead of a hard-coded class count.
    for c, atk_name in enumerate(attack_types):
        df_new = df_total[df_total['attack'] == c]
        if len(df_new) < min_samples:
            print(f"{atk_name} skipped")
            continue

        df_other = df_total[df_total['attack'] != c]
        # No estimator-count argument is passed: `test_learn_time` as defined
        # in this notebook has no such parameter.
        outcome = test_learn_time(
            df=df_other,
            df_crypt=df_new,
            additive_range=new_range,
            new_attack=atk_name,
            debug=debug,
        )
        results.append((atk_name, outcome))
        print(f"{atk_name} testing done")
    return results