In [1]:
# Contrastive Learning + Zero-Day Detection for IoMT IDS (Triplet Loss for Better Embedding)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, Lambda
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import classification_report, roc_auc_score
from scipy.spatial import distance
from sklearn.ensemble import IsolationForest
from sklearn.svm import OneClassSVM

# === Load Dataset ===
def load_dataset_from_structure(root_path):
    data = []
    for file in root_path.glob("**/*.csv"):
        try:
            df = pd.read_csv(file)
            if df.empty: continue
            category = file.parent.parent.name
            attack = file.parent.name
            label_class = 'Benign' if category.upper() == 'BENIGN' else 'Attack'
            df['category'] = category
            df['attack'] = attack
            df['class'] = label_class
            data.append(df)
        except Exception as e:
            print(f"[ERROR] Failed to read file {file}: {e}")
    return pd.concat(data, ignore_index=True) if data else pd.DataFrame()

train_root = Path('../../../Data/CICIoMT2024/train')
test_root = Path('../../../Data/CICIoMT2024/test')

train_df = load_dataset_from_structure(train_root)
test_df = load_dataset_from_structure(test_root)

df = pd.concat([train_df, test_df], ignore_index=True)

In [7]:
df['attack'].value_counts()

attack
DDOS UDP              1998026
DDOS ICMP             1887175
DDOS TCP               987063
DDOS SYN               974359
DOS UDP                704503
DOS SYN                540498
DOS ICMP               514724
DOS TCP                462480
BENIGN                 230339
DDOS CONNECT FLOOD     214952
PORT SCAN              106603
DOS PUBLISH FLOOD       52881
DDOS PUBLISH FLOOD      36039
OS SCAN                 20666
SPOOFING                17791
DOS CONNECT FLOOD       15904
MALFORMED DATA           6877
RECON VULSCAN            3207
PING SWEEP                926
Name: count, dtype: int64

In [2]:
# === Triplet Loss Function ===
def triplet_loss(margin=1.0):
    def loss(y_true, y_pred):
        anchor, positive, negative = y_pred[:, :32], y_pred[:, 32:64], y_pred[:, 64:]
        pos_dist = K.sum(K.square(anchor - positive), axis=1)
        neg_dist = K.sum(K.square(anchor - negative), axis=1)
        return K.mean(K.maximum(pos_dist - neg_dist + margin, 0.0))
    return loss


In [9]:
# === Prepare Data ===
df['attack'] = df['attack'].str.upper().str.strip()
df['class'] = df['class'].str.upper().str.strip()

print("🔎 Unique attack values:", df['attack'].unique())

zero_attack_label = 'DDOS UDP'  # Sebelumnya 'DDoS UDP', tidak cocok dengan data
benign_df = df[df['class'] == 'BENIGN'].copy()
zero_df = df[df['attack'] == zero_attack_label].copy()
attack_df = df[(df['class'] == 'ATTACK') & (df['attack'] != zero_attack_label)].copy()

triplet_size = min(len(benign_df), len(zero_df), len(attack_df))
if triplet_size == 0:
    raise ValueError("Triplet datasets are empty. Periksa kembali label 'BENIGN', 'DDoS UDP', atau struktur data.")

triplet_benign = benign_df.sample(n=triplet_size, random_state=42)
triplet_attack = attack_df.sample(n=triplet_size, random_state=42)
triplet_zero = zero_df.sample(n=triplet_size, random_state=42)

feature_cols = df.select_dtypes(include='number').columns.tolist()
triplet_benign = triplet_benign.dropna(subset=feature_cols)
triplet_attack = triplet_attack.dropna(subset=feature_cols)
triplet_zero = triplet_zero.dropna(subset=feature_cols)

valid_cols = list(set(feature_cols).intersection(
    triplet_benign.columns, triplet_attack.columns, triplet_zero.columns
))

print(f"[INFO] valid_cols: {len(valid_cols)} features")
print(f"[INFO] triplet_benign shape: {triplet_benign[valid_cols].shape}")
print(f"[INFO] triplet_attack shape: {triplet_attack[valid_cols].shape}")
print(f"[INFO] triplet_zero shape: {triplet_zero[valid_cols].shape}")

if not valid_cols:
    raise ValueError("Tidak ada fitur numerik yang cocok di antara triplet datasets. Periksa kembali struktur datanya.")

scaler = StandardScaler()
scaler.fit(df[valid_cols])

anchor = scaler.transform(triplet_benign[valid_cols])
positive = scaler.transform(triplet_attack[valid_cols])
negative = scaler.transform(triplet_zero[valid_cols])

X_triplet = np.concatenate([anchor, positive, negative], axis=1)
y_dummy = np.zeros((X_triplet.shape[0],))

🔎 Unique attack values: ['BENIGN' 'DDOS ICMP' 'DDOS SYN' 'DDOS TCP' 'DDOS UDP' 'DOS ICMP'
 'DOS SYN' 'DOS TCP' 'DOS UDP' 'DDOS CONNECT FLOOD' 'DDOS PUBLISH FLOOD'
 'DOS CONNECT FLOOD' 'DOS PUBLISH FLOOD' 'MALFORMED DATA' 'OS SCAN'
 'PING SWEEP' 'PORT SCAN' 'RECON VULSCAN' 'SPOOFING']
[INFO] valid_cols: 45 features
[INFO] triplet_benign shape: (230339, 45)
[INFO] triplet_attack shape: (230339, 45)
[INFO] triplet_zero shape: (230339, 45)


In [11]:
# === Build Triplet Network ===
def build_base_network(input_shape):
    inp = Input(shape=(input_shape,))
    x = Dense(128, activation='relu')(inp)
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu')(x)
    x = Dropout(0.3)(x)
    x = Dense(32, activation='linear')(x)
    return Model(inp, x)

input_shape = anchor.shape[1]
base_network = build_base_network(input_shape)

anchor_input = Input(shape=(input_shape,), name='anchor_input')
positive_input = Input(shape=(input_shape,), name='positive_input')
negative_input = Input(shape=(input_shape,), name='negative_input')

encoded_anchor = base_network(anchor_input)
encoded_positive = base_network(positive_input)
encoded_negative = base_network(negative_input)

merged_output = Lambda(lambda x: K.concatenate(x, axis=1))([encoded_anchor, encoded_positive, encoded_negative])
triplet_model = Model(inputs=[anchor_input, positive_input, negative_input], outputs=merged_output)

triplet_model.compile(loss=triplet_loss(margin=1.0), optimizer=Adam(0.001))
triplet_model.fit([anchor, positive, negative], y_dummy, batch_size=64, epochs=15, verbose=1)


Epoch 1/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 2ms/step - loss: 0.8214
Epoch 2/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2ms/step - loss: 0.1253
Epoch 3/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2ms/step - loss: 0.0954
Epoch 4/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2ms/step - loss: 0.0830
Epoch 5/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 2ms/step - loss: 0.0817
Epoch 6/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 2ms/step - loss: 0.0772
Epoch 7/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 2ms/step - loss: 0.0758
Epoch 8/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 2ms/step - loss: 0.0762
Epoch 9/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 2ms/step - loss: 0.0756
Epoch 10/15
[1m3600/3600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1

<keras.src.callbacks.history.History at 0x2104fdd33b0>

In [12]:
# === Embedding dan Deteksi Anomali ===
encoder_model = base_network
embedding_benign = encoder_model.predict(scaler.transform(benign_df[valid_cols]))
embedding_zero = encoder_model.predict(scaler.transform(zero_df[valid_cols]))

# Isolation Forest
iso = IsolationForest(contamination=0.05, random_state=42)
iso.fit(embedding_benign)
iso_pred = [1 if p == -1 else 0 for p in np.concatenate([iso.predict(embedding_benign), iso.predict(embedding_zero)])]

# One-Class SVM
svm = OneClassSVM(kernel='rbf', gamma='auto')
svm.fit(embedding_benign)
svm_pred = [1 if p == -1 else 0 for p in np.concatenate([svm.predict(embedding_benign), svm.predict(embedding_zero)])]

[1m7199/7199[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 714us/step
[1m62439/62439[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m47s[0m 758us/step


In [13]:
# Mahalanobis Distance
mean_vec = np.mean(embedding_benign, axis=0)
cov_inv = np.linalg.pinv(np.cov(embedding_benign, rowvar=False))
d_mahal_benign = [distance.mahalanobis(x, mean_vec, cov_inv) for x in embedding_benign]
d_mahal_zero = [distance.mahalanobis(x, mean_vec, cov_inv) for x in embedding_zero]
thresh = np.percentile(d_mahal_benign, 95)
mahal_pred = [1 if d > thresh else 0 for d in d_mahal_benign + d_mahal_zero]

In [14]:
# Ensemble
ensemble_pred = [1 if (s == 1 and m == 1) else 0 for s, m in zip(svm_pred, mahal_pred)]
y_true = [0] * len(embedding_benign) + [1] * len(embedding_zero)

print("\n[Ensemble Detection Evaluation]")
print(classification_report(y_true, ensemble_pred, target_names=['Benign', zero_attack_label]))
print("ROC-AUC:", roc_auc_score(y_true, ensemble_pred))


[Ensemble Detection Evaluation]
              precision    recall  f1-score   support

      Benign       1.00      0.95      0.97    230339
    DDOS UDP       0.99      1.00      1.00   1998026

    accuracy                           0.99   2228365
   macro avg       1.00      0.98      0.99   2228365
weighted avg       0.99      0.99      0.99   2228365

ROC-AUC: 0.9756061082811225
