In [1]:
pip install pandas numpy scikit-learn matplotlib seaborn cryptography



In [3]:
# SIT326 - Pass Task 4: Machine Learning for Malicious Traffic Detection

# Import necessary libraries
import socket
import ssl
from datetime import datetime
import pandas as pd
import numpy as np
from cryptography import x509
from cryptography.x509.oid import NameOID, ExtensionOID
from cryptography.hazmat.primitives import hashes
import os
import logging
from tqdm import tqdm # For progress bars

# --- Machine Learning Imports ---
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
from sklearn import tree
import seaborn as sns

# --- Configuration ---
BENIGN_HOSTS_FILE = '/content/drive/MyDrive/SIT326P4/benign_hosts.txt'
MALICIOUS_HOSTS_FILE = '/content/drive/MyDrive/SIT326P4/malicious_hosts.txt'
OUTPUT_CSV_FILE = '/content/drive/MyDrive/SIT326P4/tls_certificates_data.csv'
DECISION_TREE_IMG = '/content/drive/MyDrive/SIT326P4/decision_tree.png'
FEATURE_IMPORTANCE_IMG = '/content/drive/MyDrive/SIT326P4/feature_importance.png'

# Setup basic logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# ==============================================================================
# PART 1: DATA COLLECTION & FEATURE EXTRACTION
# ==============================================================================

def get_certificate_details(hostname, port=443):
    """
    Connects to a host and retrieves its TLS certificate details.
    Returns a dictionary of features or None if connection fails.
    """
    try:
        # Create a default SSL context
        context = ssl.create_default_context()
        # Connect to the server
        with socket.create_connection((hostname, port), timeout=5) as sock:
            with context.wrap_socket(sock, server_hostname=hostname) as ssock:
                # Get the certificate in binary DER format
                cert_der = ssock.getpeercert(True)
                # Parse the certificate using the cryptography library
                cert = x509.load_der_x509_certificate(cert_der)

                # --- Feature Extraction ---
                features = {}

                # 1. Validity Period (in days)
                validity_period = (cert.not_valid_after - cert.not_valid_before).days
                features['validity_period'] = validity_period

                # 2. Public Key Length (in bits)
                key_length = cert.public_key().key_size
                features['key_length'] = key_length

                # 3. Signature Algorithm
                sig_algo = cert.signature_hash_algorithm.name
                features['signature_algorithm'] = sig_algo.upper()

                # 4. Issuer Organization
                try:
                    issuer_org = cert.issuer.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)[0].value
                except IndexError:
                    issuer_org = 'N/A' # Handle cases where Org is not present
                features['issuer_org'] = issuer_org

                # 5. Subject Alternative Name (SAN) Count
                try:
                    san_extension = cert.extensions.get_extension_for_oid(ExtensionOID.SUBJECT_ALTERNATIVE_NAME)
                    san_count = len(san_extension.value.get_values_for_type(x509.DNSName))
                except x509.ExtensionNotFound:
                    san_count = 0
                features['san_count'] = san_count

                # 6. Self-Signed Status (1 if self-signed, 0 otherwise)
                is_self_signed = 1 if cert.issuer == cert.subject else 0
                features['self_signed'] = is_self_signed

                return features

    except (socket.gaierror, socket.timeout, ConnectionRefusedError, ssl.SSLError, ssl.CertificateError, OSError) as e:
        # logging.warning(f"Could not connect to or get cert for {hostname}: {e}")
        return None

def collect_data(host_file, label):
    """
    Reads a file of hostnames, collects certificate data for each,
    and returns a list of dictionaries.
    """
    certificates_data = []
    with open(host_file, 'r') as f:
        hosts = [line.strip() for line in f if line.strip()]

    # Use tqdm for a nice progress bar
    for host in tqdm(hosts, desc=f"Processing {label} hosts"):
        details = get_certificate_details(host)
        if details:
            details['label'] = label # Add the label (benign/malicious)
            certificates_data.append(details)
    return certificates_data

# ==============================================================================
# PART 2: MODEL TRAINING AND EVALUATION (Modified from original script)
# ==============================================================================

def train_and_evaluate(data_path):
    """
    Loads data, trains a Decision Tree model, evaluates it, and saves visualizations.
    """
    logging.info("Starting model training and evaluation process...")

    # 1. Load and preprocess the dataset
    try:
        data = pd.read_csv(data_path)
    except FileNotFoundError:
        logging.error(f"Data file not found at {data_path}. Please run the data collection part first.")
        return

    # Convert label to numerical (0 for benign, 1 for malicious)
    data['label'] = data['label'].apply(lambda x: 1 if x == 'malicious' else 0)

    # Handle potential missing values by filling with a placeholder or median
    data['issuer_org'].fillna('N/A', inplace=True)
    for col in ['validity_period', 'key_length', 'san_count']:
        data[col].fillna(data[col].median(), inplace=True)


    X = data.drop('label', axis=1)
    y = data['label']

    # Define categorical and numerical features
    categorical_features = ['signature_algorithm', 'issuer_org']
    numerical_features = ['validity_period', 'key_length', 'san_count', 'self_signed']

    # Create preprocessing pipelines for numerical and categorical features
    # Numerical features will be scaled. Categorical features will be one-hot encoded.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', StandardScaler(), numerical_features),
            ('cat', OneHotEncoder(handle_unknown='ignore', drop='first'), categorical_features)
        ],
        remainder='passthrough' # Keep other columns (if any)
    )

    # 2. Split data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)

    # 3. Create a full pipeline with preprocessing and the model
    # This prevents data leakage from the test set during scaling/encoding
    pipeline = Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', DecisionTreeClassifier(random_state=42))
    ])

    # Define hyperparameter grid for GridSearchCV
    param_grid = {
        'classifier__max_depth': [3, 5, 7, 10, None],
        'classifier__min_samples_split': [2, 5, 10],
        'classifier__min_samples_leaf': [1, 2, 4]
    }

    # Use GridSearchCV to find the best hyperparameters
    grid_search = GridSearchCV(pipeline, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
    grid_search.fit(X_train, y_train)

    best_model = grid_search.best_estimator_
    logging.info(f"Best Parameters found by GridSearchCV: {grid_search.best_params_}")

    # 4. Evaluate the best model on the test set
    y_pred = best_model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred)
    recall = recall_score(y_test, y_pred)
    f1 = f1_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred)

    print("\n" + "="*30)
    print("      MODEL EVALUATION RESULTS")
    print("="*30)
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    print("\nConfusion Matrix:")
    print(conf_matrix)
    print("="*30 + "\n")

    # 5. Visualize Decision Tree
    # Get feature names after one-hot encoding for the plot
    try:
        feature_names = (numerical_features +
                         list(best_model.named_steps['preprocessor']
                                       .named_transformers_['cat']
                                       .get_feature_names_out(categorical_features)))

        plt.figure(figsize=(25, 15))
        tree.plot_tree(best_model.named_steps['classifier'],
                      feature_names=feature_names,
                      class_names=['Benign', 'Malicious'],
                      filled=True,
                      rounded=True,
                      fontsize=10)
        plt.title("Decision Tree for Malicious Certificate Detection")
        plt.savefig(DECISION_TREE_IMG)
        plt.close()
        logging.info(f"Decision tree visualization saved to {DECISION_TREE_IMG}")
    except Exception as e:
        logging.error(f"Could not generate decision tree plot: {e}")


    # 6. Plot feature importance
    # Extract importance from the classifier step of the pipeline
    importances = best_model.named_steps['classifier'].feature_importances_

    feature_importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importances
    }).sort_values('Importance', ascending=False)

    plt.figure(figsize=(12, 8))
    sns.barplot(x='Importance', y='Feature', data=feature_importance_df.head(20)) # Show top 20 features
    plt.title('Top 20 Feature Importances for Decision Tree Classifier')
    plt.tight_layout()
    plt.savefig(FEATURE_IMPORTANCE_IMG)
    plt.close()
    logging.info(f"Feature importance plot saved to {FEATURE_IMPORTANCE_IMG}")


# ==============================================================================
# MAIN EXECUTION BLOCK
# ==============================================================================
if __name__ == "__main__":
    # --- Step 1: Collect Data ---
    # Check if data file already exists to avoid re-collecting
    if not os.path.exists(OUTPUT_CSV_FILE):
        logging.info("Starting data collection phase...")

        if not os.path.exists(BENIGN_HOSTS_FILE) or not os.path.exists(MALICIOUS_HOSTS_FILE):
            logging.error("Host files not found! Please create 'benign_hosts.txt' and 'malicious_hosts.txt'.")
        else:
            benign_data = collect_data(BENIGN_HOSTS_FILE, 'benign')
            malicious_data = collect_data(MALICIOUS_HOSTS_FILE, 'malicious')

            all_data = benign_data + malicious_data

            if not all_data:
                logging.error("No certificate data was collected. Check your host files and network connection.")
            else:
                # Convert to DataFrame and save to CSV
                df = pd.DataFrame(all_data)
                df.to_csv(OUTPUT_CSV_FILE, index=False)
                logging.info(f"Successfully collected data for {len(df)} certificates.")
                logging.info(f"Dataset saved to '{OUTPUT_CSV_FILE}'. This file is preserved for forensic analysis.")
    else:
        logging.info(f"Data file '{OUTPUT_CSV_FILE}' already exists. Skipping data collection.")

    # --- Step 2: Train and Evaluate Model ---
    if os.path.exists(OUTPUT_CSV_FILE):
        train_and_evaluate(OUTPUT_CSV_FILE)
    else:
        logging.error("Cannot proceed to training as no data file is available.")

  validity_period = (cert.not_valid_after - cert.not_valid_before).days
  validity_period = (cert.not_valid_after - cert.not_valid_before).days
Processing benign hosts: 100%|██████████| 2000/2000 [02:57<00:00, 11.24it/s]
Processing malicious hosts: 100%|██████████| 2055/2055 [06:18<00:00,  5.43it/s]
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  data['issuer_org'].fillna('N/A', inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inpl


      MODEL EVALUATION RESULTS
Accuracy:  1.0000
Precision: 1.0000
Recall:    1.0000
F1-Score:  1.0000

Confusion Matrix:
[[600   0]
 [  0 340]]

