In [8]:
import pickle
import re
from datetime import datetime

import pandas as pd
from pyspark.ml.feature import StringIndexer
from pyspark.sql import SparkSession
from sklearn.preprocessing import LabelEncoder

In [9]:
DIRECTORY = '..'
# DIRECTORY = 'generated_data'

In [10]:
csv_files = [
    "normal-traffic",
    "port-scanning",
    "ddos-tcp-syn-flood"
]

spark = SparkSession.builder \
    .appName("Prepare and sort logs") \
    .getOrCreate()


Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/06/14 18:09:45 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Wybieramy sobie kolumny zawierające istotne iformacje. Można dodać więcej ale wtedy trzeba pamiętać o noramlizacji w kolejnej komórce.

In [11]:
selected_columns = [
    "frame-time",
    "arp-opcode",
    "arp-hw-size",
    "ip-src_host",
    "ip-dst_host",
    "tcp-ack",
    "tcp-ack_raw",
    "tcp-connection-fin",
    "tcp-connection-rst",
    "tcp-connection-syn",
    "tcp-connection-synack",
    "tcp-dstport",
    "tcp-flags_index",
    "tcp-flags-ack",
    "tcp-len",
    "tcp-seq",
    "tcp-srcport",
    "udp-port",
    "udp-stream",
    "udp-time_delta",
    "dns-qry-name",
    "dns-qry-name-len_index",
    "dns-qry-qu_index",
    "dns-qry-type",
    "dns-retransmission",
    "dns-retransmit_request",
    "dns-retransmit_request_in",
    "mqtt-conack-flags_index",
    "mqtt-conflag-cleansess",
    "mqtt-conflags_index",
    "mqtt-hdrflags_index",
    "mqtt-len",
    "mqtt-msg_index",
    "mqtt-msgtype",
    "mqtt-proto_len",
    "mqtt-protoname_index",
    "mqtt-topic_index",
    "mqtt-topic_len",
    "mqtt-ver",
    "Attack_type"
]

Iterujemy po wczytanych ramkach, zamieniamy nazwy kolumn na takie bez kropek i normalizujemy/kodujemy nieliczbowe kolumny (oprócz timestampów, ta kolumna jest modyfikowana później). Odchudzone dane zapisujemy do katalogu `preprocessed_data`

In [12]:
def timestamp_to_epoch(timestamp):
    dt = datetime.fromisoformat(f'2024-06-05 {str(timestamp).strip().split()[1]}')
    return dt.timestamp()


In [None]:
from sklearn.preprocessing import MinMaxScaler

def normalize_data(pandas_df):
    columns_to_normalize = ['tcp-ack_raw', 'tcp-ack', 'tcp-dstport', 'tcp-len', 'tcp-seq', 'tcp-srcport']
    scaler = MinMaxScaler()
    pandas_df[columns_to_normalize] = scaler.fit_transform(pandas_df[columns_to_normalize])
    return pandas_df

In [13]:
for file in csv_files:
    
    df = spark.read.csv(DIRECTORY + '/' + file + '.csv', header=True, inferSchema=True)

    for col_name in df.columns:
        new_col_name = re.sub(r'\.', '-', col_name)
        df = df.withColumnRenamed(col_name, new_col_name)
    
    tcp_flags_indexer = StringIndexer(inputCol="tcp-flags", outputCol="tcp-flags_index")
    indexed_df = tcp_flags_indexer.fit(df).transform(df)

    dns_qry_name_len_indexer = StringIndexer(inputCol="dns-qry-name-len", outputCol="dns-qry-name-len_index")
    indexed_df = dns_qry_name_len_indexer.fit(indexed_df).transform(indexed_df)
    dns_qry_qu_indexer = StringIndexer(inputCol="dns-qry-qu", outputCol="dns-qry-qu_index")
    indexed_df = dns_qry_qu_indexer.fit(indexed_df).transform(indexed_df)

    mqtt_conack_flags_indexer = StringIndexer(inputCol="mqtt-conack-flags", outputCol="mqtt-conack-flags_index")
    indexed_df = mqtt_conack_flags_indexer.fit(indexed_df).transform(indexed_df)
    mqtt_conflags = StringIndexer(inputCol="mqtt-conflags", outputCol="mqtt-conflags_index")
    indexed_df = mqtt_conflags.fit(indexed_df).transform(indexed_df)

    mqtt_hdrflags = StringIndexer(inputCol="mqtt-hdrflags", outputCol="mqtt-hdrflags_index")
    indexed_df = mqtt_hdrflags.fit(indexed_df).transform(indexed_df)
    mqtt_msg = StringIndexer(inputCol="mqtt-msg", outputCol="mqtt-msg_index")
    indexed_df = mqtt_msg.fit(indexed_df).transform(indexed_df)

    mqtt_protoname = StringIndexer(inputCol="mqtt-protoname", outputCol="mqtt-protoname_index")
    indexed_df = mqtt_protoname.fit(indexed_df).transform(indexed_df)

    mqtt_topic = StringIndexer(inputCol="mqtt-topic", outputCol="mqtt-topic_index")
    indexed_df = mqtt_topic.fit(indexed_df).transform(indexed_df)

    pandas_df = indexed_df.select(selected_columns).toPandas()
    pandas_df = pandas_df.drop_duplicates()
    pandas_df = normalize_data(pandas_df)

    all_ips = pd.concat([pandas_df["ip-src_host"], pandas_df["ip-dst_host"]]).unique()
    label_encoder = LabelEncoder()
    label_encoder.fit(all_ips)
    pandas_df["ip-src_host"] = label_encoder.transform(pandas_df["ip-src_host"])
    pandas_df["ip-dst_host"] = label_encoder.transform(pandas_df["ip-dst_host"])
    pandas_df['frame-time'] = pandas_df['frame-time'].apply(timestamp_to_epoch)

    pandas_df = pandas_df.reindex(sorted(pandas_df.columns), axis=1)
    pandas_df.to_csv(f'{DIRECTORY}/{file}-preprocessed.csv', index=False)

24/06/14 18:12:00 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
                                                                                

In [14]:
spark.stop()

Wczytujemy zapisane pliki csv i tworzymy próbki z danymi, gdzie jedna próbka X to lista zawierająca kolejne 32 logi gdzie od każdego timestampa został odjęty timestamp pierwszego loga z listy (w ten sposób timestampy są niewielkimi wartościami liczbowymi a jednocześnie przechowują informację o odległości pomiędzy kolejnymi logami), a próbka Y to pojedynczy numer określający typ ataku/ruchu normalnego dla zagregowanych logów.

In [15]:
def logs_to_series(df, logs_per_bucket):
    del df['Attack_type']
    buckets = []

    for i in range(0, df.shape[0], logs_per_bucket):
        if df.shape[0] >= i + logs_per_bucket:
            bucket = df.iloc[i:i + logs_per_bucket]
            bucket['frame-time'] = bucket['frame-time'] - bucket['frame-time'].iloc[0]
            buckets.append(bucket)

    return buckets

In [16]:
encoded_attacks = {
    "normal-traffic": 0,
    "port-scanning": 1,
    "ddos-tcp-syn-flood": 2
}
x_data = []
y_data = []
for file in csv_files:
    df = pd.read_csv(f'{DIRECTORY}/{file}-preprocessed.csv')
    log_series = logs_to_series(df, 32)
    x_data.extend(log_series)
    y_data.extend([encoded_attacks.get(file)] * len(log_series))

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  bucket['frame-time'] = bucket['frame-time'] - bucket['frame-time'].iloc[0]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  bucket['frame-time'] = bucket['frame-time'] - bucket['frame-time'].iloc[0]
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  bucket['frame-time'] = bucket['frame-time'] - bucket['f

Przed treningem modeli należy jeszcze pomieszać próbki z danymi oraz podzielić na zbiory treningowe i testowe. W sumie dobrze by też było dodać jakiś padding dla przypadków gdzie jednak próbka nie ma 32 logów.

In [17]:
output_path = f'{DIRECTORY}/processed_data.pkl'
print("Writing log series into ", output_path)
with open(output_path, 'wb') as f:
    pickle.dump((x_data, y_data), f)

Writing log series into  ../processed_data.pkl


In [19]:
from datetime import datetime, timedelta
import os
import random

import numpy as np
import pandas as pd

TRAFFIC_TYPES = {
    "normal-traffic": "Normal",
    "port-scanning": "Port_Scanning",
    "ddos-tcp-syn-flood": "DDoS_TCP"
}


def generate_logs_for_traffic(traffic_type, logs_count):
    match traffic_type:
        case 'Normal':
            return generate_logs_for_normal_traffic(logs_count)
        case 'Port_Scanning':
            return generate_logs_for_port_scanning(logs_count)
        case 'DDoS_TCP':
            return generate_logs_for_ddos_tcp(logs_count)


def generate_logs_for_normal_traffic(logs_count):
    logs = []
    previous_timestamp = datetime.now().astimezone()
    while len(logs) <= logs_count:
        logs.append(generate_normal_traffic_log(previous_timestamp))
        previous_timestamp = datetime.strptime(logs[-1][0], '%Y-%m-%dT%H:%M:%S.%f%z')
    return logs


def generate_normal_traffic_log(previous_timestamp):
    frame_time = (previous_timestamp + timedelta(milliseconds=random.randint(5, 20))).strftime('%Y-%m-%dT%H:%M:%S.%f%z')
    arp_opcode = 0.0
    arp_hw_siz = 0.0
    ip_src_host = draw_ip_address()
    ip_dst_host = draw_ip_address()
    tcp_ack = np.random.choice([0.0, 1.0, 5.0, 6.0, 15.0, 56.0, 59.0], p=[0.13, 0.24, 0.13, 0.19, 0.06, 0.06, 0.19])
    tcp_ack_raw = 0.0 if np.random.choice([0.0, 1.0], p=[0.2, 0.8]) == 0.0 else random.randint(9749767.0, 4292762365.0)
    tcp_connection_fin = np.random.choice([0.0, 1.0], p=[0.88, 0.12])
    tcp_connection_rst = np.random.choice([0.0, 1.0], p=[0.88, 0.12])
    tcp_connection_syn = np.random.choice([0.0, 1.0], p=[0.92, 0.08])
    tcp_connection_synack = np.random.choice([0.0, 1.0], p=[0.93, 0.07])
    tcp_dstport = random.randint(51173, 65156)
    tcp_srcport = random.randint(51173, 65156)
    tcp_len = np.random.choice([0.0, 2.0, 4.0, 14.0, 41.0], p=[0.76, 0.06, 0.06, 0.06, 0.06])
    tcp_seq = np.random.choice([0.0, 1.0, 5.0, 6.0, 15.0, 56.0, 59.0], p=[0.13, 0.24, 0.13, 0.19, 0.06, 0.06, 0.19])
    tcp_flags = np.random.choice(
        ['0.0', '0x00000002', '0x00000004', '0x00000010', '0x00000011', '0x00000012', '0x00000018', '0x00000019'], 
        p=[0.01, 0.06, 0.12, 0.44, 0.06, 0.07, 0.18, 0.06]
    )
    tcp_flags_ack = np.random.choice([0.0, 1.0], p=[0.24, 0.76])
    udp_port = 0.0
    udp_stream = 0.0
    udp_time_delta = 0.0
    dns_qry_name = 0.0
    dns_qry_name_len = 0
    dns_qry_qu = 0
    dns_qry_type = 0.0
    dns_retransmission = 0.0
    dns_retransmit_request = 0.0
    dns_retransmit_request_in = 0.0
    mqtt_conack_flags = np.random.choice(['0', '0x00000000'], p=[0.93, 0.07])
    mqtt_conflag_cleansess = np.random.choice([0.0, 1.0], p=[0.94, 0.06])
    mqtt_conflags = np.random.choice(['0', '0x00000002'], p=[0.93, 0.07])
    mqtt_hdrflags = np.random.choice(
        ['0.0', '0x00000010', '0x00000020', '0x00000030', '0x000000e0'], 
        p=[0.76, 0.06, 0.06, 0.06, 0.06]
    )
    mqtt_len = np.random.choice([0, 0.0, 2.0, 12.0, 39.0], p=[0.76, 0.06, 0.06, 0.06, 0.06])
    mqtt_msg = 0
    mqtt_msgtype = np.random.choice([0.0, 1.0, 2.0, 3.0, 14.0], p=[0.76, 0.06, 0.06, 0.06, 0.06])
    mqtt_proto_len = np.random.choice([0, 4.0], p=[0.94, 0.06])
    mqtt_protoname = np.random.choice(['0', 'MQTT'], p=[0.94, 0.06])
    mqtt_topic = 0
    mqtt_topic_len = 0.0
    mqtt_ver = np.random.choice([0, 4.0], p=[0.94, 0.06])

    return [frame_time, arp_opcode, arp_hw_siz, ip_src_host, ip_dst_host, tcp_ack, tcp_ack_raw,
            tcp_connection_fin, tcp_connection_rst, tcp_connection_syn, tcp_connection_synack,
            tcp_srcport, tcp_dstport, tcp_flags, tcp_flags_ack, tcp_len, tcp_seq, udp_port,
            udp_stream, udp_time_delta, dns_qry_name, dns_qry_name_len, dns_qry_qu, dns_qry_type, 
            dns_retransmission, dns_retransmit_request, dns_retransmit_request_in, mqtt_conack_flags, 
            mqtt_conflag_cleansess, mqtt_conflags, mqtt_hdrflags, mqtt_len, mqtt_msg, mqtt_msgtype, 
            mqtt_proto_len, mqtt_protoname, mqtt_topic, mqtt_topic_len, mqtt_ver, 'Normal']


def generate_logs_for_port_scanning(logs_count):
    logs = []
    hacker_ip = draw_ip_address()
    victim_ip = draw_ip_address()
    victim_port = 1000
    previous_timestamp = datetime.now().astimezone()
    while len(logs) <= logs_count:
        logs.extend(generate_port_scanning_log_pair(previous_timestamp, hacker_ip, victim_ip, victim_port))
        previous_timestamp = datetime.strptime(logs[-1][0], '%Y-%m-%dT%H:%M:%S.%f%z')
        victim_port += 1
    return logs


def generate_port_scanning_log_pair(previous_timestamp, hacker_ip, victim_ip, victim_port):
    frame_time = (previous_timestamp + timedelta(milliseconds=random.randint(1, 5))).strftime('%Y-%m-%dT%H:%M:%S.%f%z')
    hacker_port = 80
    arp_opcode = 0.0
    arp_hw_siz = 0.0
    tcp_ack = random.randint(341294.0, 2147250934.0)
    tcp_connection_fin = 0.0
    tcp_connection_synack = 0.0
    tcp_len = 0.0
    udp_port = 0.0
    udp_stream = 0.0
    udp_time_delta = 0.0
    dns_qry_name = 0.0
    dns_qry_name_len = 0.0
    dns_qry_qu = 0.0
    dns_qry_type = 0.0
    dns_retransmission = 0.0
    dns_retransmit_request = 0.0
    dns_retransmit_request_in = 0.0
    mqtt_conack_flags = 0.0
    mqtt_conflag_cleansess = 0.0
    mqtt_conflags = 0.0
    mqtt_hdrflags = 0.0
    mqtt_len = 0.0
    mqtt_msg = 0.0
    mqtt_msgtype = 0.0
    mqtt_proto_len = 0.0
    mqtt_protoname = 0.0
    mqtt_topic = 0.0
    mqtt_topic_len = 0.0
    mqtt_ver = 0.0
    traffic_type = 'Port_Scanning'

    return [
            [
                frame_time, arp_opcode, arp_hw_siz, hacker_ip, victim_ip, tcp_ack, tcp_ack, tcp_connection_fin,
                0.0, 1.0, tcp_connection_synack, hacker_port, victim_port, '0x00000014', 1.0, tcp_len, 1.0,
                udp_port, udp_stream, udp_time_delta, dns_qry_name, dns_qry_name_len, dns_qry_qu, dns_qry_type, 
                dns_retransmission, dns_retransmit_request, dns_retransmit_request_in, mqtt_conack_flags, 
                mqtt_conflag_cleansess, mqtt_conflags, mqtt_hdrflags, mqtt_len, mqtt_msg, mqtt_msgtype, 
                mqtt_proto_len, mqtt_protoname, mqtt_topic, mqtt_topic_len, mqtt_ver, traffic_type
            ],
            [
                frame_time, arp_opcode, arp_hw_siz, victim_ip, hacker_ip, 1.0, tcp_ack, tcp_connection_fin,
                1.0, 0.0, tcp_connection_synack, victim_port, hacker_port, '0x00000002', 0.0, tcp_len, 0.0,
                udp_port, udp_stream, udp_time_delta, dns_qry_name, dns_qry_name_len, dns_qry_qu, dns_qry_type, 
                dns_retransmission, dns_retransmit_request, dns_retransmit_request_in, mqtt_conack_flags, 
                mqtt_conflag_cleansess, mqtt_conflags, mqtt_hdrflags, mqtt_len, mqtt_msg, mqtt_msgtype, 
                mqtt_proto_len, mqtt_protoname, mqtt_topic, mqtt_topic_len, mqtt_ver,traffic_type
            ]
        ]


def generate_logs_for_ddos_tcp(logs_count):
    logs = []
    victim_ip = draw_ip_address()
    hacker_port = 30000
    previous_timestamp = datetime.now().astimezone()
    while len(logs) < logs_count:
        logs.extend(generate_ddos_log_pair(previous_timestamp, victim_ip, hacker_port))
        previous_timestamp = datetime.strptime(logs[-1][0], '%Y-%m-%dT%H:%M:%S.%f%z')
        hacker_port += 1
    return logs


def generate_ddos_log_pair(previous_timestamp, victim_ip, hacker_port):
    if random.randint(1, 10) == 1:
        frame_time = (previous_timestamp + timedelta(milliseconds=1)).strftime('%Y-%m-%dT%H:%M:%S.%f%z')
    else:
        frame_time = previous_timestamp.strftime('%Y-%m-%dT%H:%M:%S.%f%z')
    hacker_ip = draw_ip_address()
    victim_port = 80
    arp_opcode = 0.0
    arp_hw_siz = 0.0
    tcp_ack = random.randint(341294.0, 2147250934.0)
    tcp_connection_fin = 0.0
    tcp_connection_synack = 0.0
    udp_port = 0.0
    udp_stream = 0.0
    udp_time_delta = 0.0
    dns_qry_name = 0.0
    dns_qry_name_len = 0.0
    dns_qry_qu = 0.0
    dns_qry_type = 0.0
    dns_retransmission = 0.0
    dns_retransmit_request = 0.0
    dns_retransmit_request_in = 0.0
    mqtt_conack_flags = 0.0
    mqtt_conflag_cleansess = 0.0
    mqtt_conflags = 0.0
    mqtt_hdrflags = 0.0
    mqtt_len = 0.0
    mqtt_msg = 0.0
    mqtt_msgtype = 0.0
    mqtt_proto_len = 0.0
    mqtt_protoname = 0.0
    mqtt_topic = 0.0
    mqtt_topic_len = 0.0
    mqtt_ver = 0.0
    traffic_type = 'DDoS_TCP'

    log_pair = [
        [
            frame_time, arp_opcode, arp_hw_siz, hacker_ip, victim_ip, tcp_ack, tcp_ack, tcp_connection_fin,
            0.0, 1.0, tcp_connection_synack, hacker_port, victim_port, '0x00000002', 0.0, 120.0, 0.0,
            udp_port, udp_stream, udp_time_delta, dns_qry_name, dns_qry_name_len, dns_qry_qu, dns_qry_type,
            dns_retransmission, dns_retransmit_request, dns_retransmit_request_in, mqtt_conack_flags,
            mqtt_conflag_cleansess, mqtt_conflags, mqtt_hdrflags, mqtt_len, mqtt_msg, mqtt_msgtype,
            mqtt_proto_len, mqtt_protoname, mqtt_topic, mqtt_topic_len, mqtt_ver, traffic_type
        ],
        [
            frame_time, arp_opcode, arp_hw_siz, victim_ip, hacker_ip, 121.0, tcp_ack, tcp_connection_fin,
            1.0, 0.0, tcp_connection_synack, victim_port, hacker_port, '0x00000014', 1.0, 0.0, 1.0,
            udp_port, udp_stream, udp_time_delta, dns_qry_name, dns_qry_name_len, dns_qry_qu, dns_qry_type,
            dns_retransmission, dns_retransmit_request, dns_retransmit_request_in, mqtt_conack_flags,
            mqtt_conflag_cleansess, mqtt_conflags, mqtt_hdrflags, mqtt_len, mqtt_msg, mqtt_msgtype,
            mqtt_proto_len, mqtt_protoname, mqtt_topic, mqtt_topic_len, mqtt_ver, traffic_type
        ]
    ]

    if random.randint(1, 10) == 1:
        log_pair.pop()

    return log_pair


def draw_ip_address():
    return str(random.randint(0, 255)) + '.' \
        + str(random.randint(0, 255)) + '.' \
        + str(random.randint(0, 255)) + '.' \
        + str(random.randint(0, 255))


OUTDIR = './generated_data'
if not os.path.exists(OUTDIR):
    os.mkdir(OUTDIR)

columns = [
    "frame-time",
    "arp-opcode",
    "arp-hw-size",
    "ip-src_host",
    "ip-dst_host",
    "tcp-ack",
    "tcp-ack_raw",
    "tcp-connection-fin",
    "tcp-connection-rst",
    "tcp-connection-syn",
    "tcp-connection-synack",
    "tcp-srcport",
    "tcp-dstport",
    "tcp-flags",
    "tcp-flags-ack",
    "tcp-len",
    "tcp-seq",
    "udp-port",
    "udp-stream",
    "udp-time_delta",
    "dns-qry-name",
    "dns-qry-name-len",
    "dns-qry-qu",
    "dns-qry-type",
    "dns-retransmission",
    "dns-retransmit_request",
    "dns-retransmit_request_in",
    "mqtt-conack-flags",
    "mqtt-conflag-cleansess",
    "mqtt-conflags",
    "mqtt-hdrflags",
    "mqtt-len",
    "mqtt-msg",
    "mqtt-msgtype",
    "mqtt-proto_len",
    "mqtt-protoname",
    "mqtt-topic",
    "mqtt-topic_len",
    "mqtt-ver",
    "Attack_type"
]
    
for filename, traffic_type in TRAFFIC_TYPES.items():
    generated_logs = generate_logs_for_traffic(traffic_type, 10000)
    
    data = pd.DataFrame(generated_logs, columns=columns)
    print(f'Saving data to file: {OUTDIR}/{filename}.csv')
    data.to_csv(f'{OUTDIR}/{filename}.csv', index=False)

Saving data to file: ./generated_data/normal-traffic.csv
Saving data to file: ./generated_data/port-scanning.csv
Saving data to file: ./generated_data/ddos-tcp-syn-flood.csv


In [20]:
import numpy as np
import pickle

from sklearn.naive_bayes import GaussianNB
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split, GridSearchCV

In [21]:
def load_data(filename):
    with open(filename, 'rb') as f:
        x, y = pickle.load(f)

    rows_count = x[0].shape[0]
    x = [bucket for bucket in x if bucket.shape[0] == rows_count]

    x = np.array([bucket.to_numpy().flatten() for bucket in x])
    y = np.array(y[:len(x)])
    
    return x, y

In [22]:
x_data_np, y_data_np = load_data('../processed_data.pkl')
x_gen, y_gen = load_data('./generated_data/processed_data.pkl')
x_train, x_test, y_train, y_test = train_test_split(x_data_np, y_data_np, test_size=0.2, random_state=42)

In [23]:
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import label_binarize


def evaluate_model(clf, x_gen, y_gen):
    y_pred = clf.predict(x_gen)
    y_proba = clf.predict_proba(x_gen)

    print(f'Accuracy: {accuracy_score(y_gen, y_pred)}')
    print(f'Confusion Matrix:\n{confusion_matrix(y_gen, y_pred)}')
    print(f'Classification Report:\n{classification_report(y_gen, y_pred,)}')
    
    target_names = np.unique(y_gen)
    y_test_bin = label_binarize(y_gen, classes=np.arange(len(target_names)))
    

    roc_auc = roc_auc_score(y_test_bin, y_proba, multi_class='ovr')
    
    print(f'ROC-AUC Score: {roc_auc}')

In [24]:

# Train a classifier (example with Gaussian Naive Bayes)
nb_clf = GaussianNB()
nb_clf.fit(x_train, y_train)

evaluate_model(nb_clf, x_test, y_test)
evaluate_model(nb_clf, x_gen, y_gen)

Accuracy: 1.0
Confusion Matrix:
[[3263    0    0]
 [   0  156    0]
 [   0    0 3072]]
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      3263
           1       1.00      1.00      1.00       156
           2       1.00      1.00      1.00      3072

    accuracy                           1.00      6491
   macro avg       1.00      1.00      1.00      6491
weighted avg       1.00      1.00      1.00      6491

ROC-AUC Score: 1.0
Accuracy: 0.6666666666666666
Confusion Matrix:
[[312   0   0]
 [  0 312   0]
 [  0 312   0]]
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       312
           1       0.50      1.00      0.67       312
           2       0.00      0.00      0.00       312

    accuracy                           0.67       936
   macro avg       0.50      0.67      0.56       936
weighted avg       0.50      0.67      0.56     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [25]:

rf_clf = RandomForestClassifier()
rf_clf.fit(x_train, y_train)

evaluate_model(rf_clf, x_test, y_test)
evaluate_model(rf_clf, x_gen, y_gen)


Accuracy: 1.0
Confusion Matrix:
[[3263    0    0]
 [   0  156    0]
 [   0    0 3072]]
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00      3263
           1       1.00      1.00      1.00       156
           2       1.00      1.00      1.00      3072

    accuracy                           1.00      6491
   macro avg       1.00      1.00      1.00      6491
weighted avg       1.00      1.00      1.00      6491

ROC-AUC Score: 1.0
Accuracy: 0.3333333333333333
Confusion Matrix:
[[  0   0 312]
 [312   0   0]
 [  0   0 312]]
Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       312
           1       0.00      0.00      0.00       312
           2       0.50      1.00      0.67       312

    accuracy                           0.33       936
   macro avg       0.17      0.33      0.22       936
weighted avg       0.17      0.33      0.22     

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
