In [8]:
import pandas as pd
import numpy as np
import ast
import joblib
import re
from datetime import datetime
from tensorflow.keras.models import load_model
from scapy.all import sniff, DNS, DNSQR, DNSRR
from math import log2
from collections import Counter

In [9]:
# ---------------------------
# Load the saved model and preprocessing objects
# ---------------------------
model = load_model('models/model_batchsize64_sequence5.h5')
le_label = joblib.load('encoders/label_encoder.pkl')
le_tld = joblib.load('encoders/le_dns_top_level_domain.pkl')
le_sld = joblib.load('encoders/le_dns_second_level_domain.pkl')
scaler = joblib.load('encoders/scaler.pkl')



In [10]:
# ---------------------------
# Helper functions for feature extraction
# ---------------------------
def parse_list_safe(list_entry):
    """Safely parse a list from a string or return it if already a list."""
    if isinstance(list_entry, list):
        return list_entry
    elif isinstance(list_entry, str):
        try:
            return ast.literal_eval(list_entry)
        except (ValueError, SyntaxError):
            return []
    else:
        return []

def extract_vowels_consonants(dist):
    """Extract vowel and consonant counts from a character distribution."""
    vowels = set('aeiou')
    consonants = set('bcdfghjklmnpqrstvwxyz')
    vowel_count = 0
    consonant_count = 0
    if isinstance(dist, str):
        try:
            dist = ast.literal_eval(dist)
        except (ValueError, SyntaxError):
            dist = {}
    if isinstance(dist, dict):
        for ch, cnt in dist.items():
            if ch in vowels:
                vowel_count += cnt
            elif ch in consonants:
                consonant_count += cnt
    return vowel_count, consonant_count

In [11]:
# ---------------------------
# Define the expected feature order (must match training)
# ---------------------------
feature_order = [
    'dns_domain_name_length',
    'numerical_percentage',
    'character_entropy',
    'max_continuous_numeric_len',
    'max_continuous_alphabet_len',
    'vowels_consonant_ratio',
    'conv_freq_vowels_consonants',
    'packets_numbers',
    'receiving_packets_numbers',
    'sending_packets_numbers',
    'receiving_bytes',
    'sending_bytes',
    'distinct_ttl_values',
    'ttl_values_min',
    'ttl_values_max',
    'ttl_values_mean',
    'dns_top_level_domain_encoded',
    'dns_second_level_domain_encoded',
    'uni_gram_count',
    'bi_gram_count',
    'tri_gram_count',
    'query_resource_record_type_count',
    'ans_resource_record_type_count',
    'query_resource_record_class_count',
    'ans_resource_record_class_count',
    'vowel_count',
    'consonant_count'
]

In [12]:
# ---------------------------
# Function to preprocess a single DNS event
# ---------------------------
def preprocess_event(event):
    # Convert timestamp (if needed; here we keep it for potential logging)
    try:
        _ = pd.to_datetime(event.get('timestamp'))
    except Exception:
        pass  # Inference does not use the timestamp for scaling

    # Process DNS domains
    placeholder = 'unknown'
    tld = event.get('dns_top_level_domain', placeholder) or placeholder
    sld = event.get('dns_second_level_domain', placeholder) or placeholder
    # If not in the known classes, default to 'unknown'
    if tld not in le_tld.classes_:
        tld = 'unknown'
    if sld not in le_sld.classes_:
        sld = 'unknown'
    tld_encoded = int(le_tld.transform([tld])[0])
    sld_encoded = int(le_sld.transform([sld])[0])
    
    # Process n-gram features by computing counts
    uni_gram_list = parse_list_safe(event.get('uni_gram_domain_name', []))
    bi_gram_list = parse_list_safe(event.get('bi_gram_domain_name', []))
    tri_gram_list = parse_list_safe(event.get('tri_gram_domain_name', []))
    uni_gram_count = len(uni_gram_list)
    bi_gram_count = len(bi_gram_list)
    tri_gram_count = len(tri_gram_list)
    
    # Process resource record features by computing unique counts
    def count_unique(val):
        lst = parse_list_safe(val)
        return len(set(lst)) if lst else 0
    query_rr_type_count = count_unique(event.get('query_resource_record_type', []))
    ans_rr_type_count = count_unique(event.get('ans_resource_record_type', []))
    query_rr_class_count = count_unique(event.get('query_resource_record_class', []))
    ans_rr_class_count = count_unique(event.get('ans_resource_record_class', []))
    
    # Extract vowel and consonant counts from character distribution
    vowel_count, consonant_count = extract_vowels_consonants(event.get('character_distribution', {}))
    
    # conv_freq_vowels_consonants may not be present in the event;
    # set a default value (adjust as needed)
    conv_freq = event.get('conv_freq_vowels_consonants', 0.0)
    
    # Build a dictionary of features that matches the order used during training.
    features = {
        'dns_domain_name_length': event.get('dns_domain_name_length', 0),
        'numerical_percentage': event.get('numerical_percentage', 0.0),
        'character_entropy': event.get('character_entropy', 0.0),
        'max_continuous_numeric_len': event.get('max_continuous_numeric_len', 0),
        'max_continuous_alphabet_len': event.get('max_continuous_alphabet_len', 0),
        'packets_numbers': event.get('packets_numbers', 0),
        'receiving_packets_numbers': event.get('receiving_packets_numbers', 0),
        'sending_packets_numbers': event.get('sending_packets_numbers', 0),
        'receiving_bytes': event.get('receiving_bytes', 0),
        'sending_bytes': event.get('sending_bytes', 0),
        'distinct_ttl_values': event.get('distinct_ttl_values', 0),
        'ttl_values_min': event.get('ttl_values_min', -1),
        'ttl_values_max': event.get('ttl_values_max', -1),
        'ttl_values_mean': event.get('ttl_values_mean', -1.0),
        'uni_gram_count': uni_gram_count,
        'bi_gram_count': bi_gram_count,
        'tri_gram_count': tri_gram_count,
        'query_resource_record_type_count': query_rr_type_count,
        'ans_resource_record_type_count': ans_rr_type_count,
        'query_resource_record_class_count': query_rr_class_count,
        'ans_resource_record_class_count': ans_rr_class_count,
        'vowels_consonant_ratio': event.get('vowels_consonant_ratio', 0.0),
        'conv_freq_vowels_consonants': conv_freq,
        'vowel_count': vowel_count,
        'consonant_count': consonant_count,
        'dns_top_level_domain_encoded': tld_encoded,
        'dns_second_level_domain_encoded': sld_encoded
    }
    return features

In [13]:
# ---------------------------
# Buffer for maintaining a sliding window (sequence) of DNS events
# ---------------------------
sequence_length = 5
event_buffer = []

In [14]:
# ---------------------------
# Function to predict the label for a DNS event using the sliding window
# ---------------------------
def predict_dns_event(event):
    # Preprocess the incoming event
    features = preprocess_event(event)
    
    # Create a DataFrame (ensuring columns follow the expected order)
    df_event = pd.DataFrame([features], columns=feature_order)
    
    # Scale the features (all columns are numeric; timestamp is not used here)
    scaled_features = scaler.transform(df_event)
    
    # Update the buffer: append the new scaled event
    event_buffer.append(scaled_features[0])
    
    # If we don't have enough events, duplicate the current event to pad the buffer
    while len(event_buffer) < sequence_length:
        event_buffer.insert(0, scaled_features[0])
    
    # Ensure the buffer contains only the last 'sequence_length' events
    if len(event_buffer) > sequence_length:
        event_buffer.pop(0)
    
    # Form a sequence array for the model input with shape (1, sequence_length, num_features)
    sequence_input = np.array(event_buffer).reshape(1, sequence_length, -1)
    
    # Get prediction probabilities from the model
    pred_probs = model.predict(sequence_input)
    pred_class = np.argmax(pred_probs, axis=1)[0]
    predicted_label = le_label.inverse_transform([pred_class])[0]
    
    return predicted_label, pred_probs

In [17]:
def process_packet(packet):
    if packet.haslayer(DNS):
        dns_layer = packet.getlayer(DNS)
        event_type = "Query" if dns_layer.qr == 0 else "Response"
        # Extract query name from DNSQR or DNSRR
        if packet.haslayer(DNSQR):
            query_name = packet[DNSQR].qname.decode('utf-8').strip('.')
        elif packet.haslayer(DNSRR):
            query_name = packet[DNSRR].rrname.decode('utf-8').strip('.')
        else:
            query_name = "unknown"
        
        event = {}
        event['timestamp'] = datetime.fromtimestamp(packet.time).strftime("%Y-%m-%d %H:%M:%S.%f")
        parts = query_name.split('.')
        if len(parts) >= 2:
            event['dns_top_level_domain'] = parts[-1]
            event['dns_second_level_domain'] = parts[-2]
        else:
            event['dns_top_level_domain'] = 'unknown'
            event['dns_second_level_domain'] = 'unknown'
        event['dns_domain_name_length'] = len(query_name)
        digits = sum(c.isdigit() for c in query_name)
        event['numerical_percentage'] = digits / len(query_name) if query_name else 0.0
        freq = {}
        for c in query_name:
            freq[c] = freq.get(c, 0) + 1
        entropy = -sum((count/len(query_name)) * log2(count/len(query_name)) for count in freq.values()) if query_name else 0.0
        event['character_entropy'] = entropy
        numeric_runs = re.findall(r'\d+', query_name)
        event['max_continuous_numeric_len'] = max((len(run) for run in numeric_runs), default=0)
        alpha_runs = re.findall(r'[a-zA-Z]+', query_name)
        event['max_continuous_alphabet_len'] = max((len(run) for run in alpha_runs), default=0)
        event['packets_numbers'] = 1
        event['receiving_packets_numbers'] = 1 if dns_layer.qr == 1 else 0
        event['sending_packets_numbers'] = 1 if dns_layer.qr == 0 else 0
        event['receiving_bytes'] = len(packet) if dns_layer.qr == 1 else 0
        event['sending_bytes'] = len(packet) if dns_layer.qr == 0 else 0
        if packet.haslayer(DNSRR):
            ttl = packet[DNSRR].ttl
            event['distinct_ttl_values'] = 1
            event['ttl_values_min'] = ttl
            event['ttl_values_max'] = ttl
            event['ttl_values_mean'] = float(ttl)
        else:
            event['distinct_ttl_values'] = 0
            event['ttl_values_min'] = -1
            event['ttl_values_max'] = -1
            event['ttl_values_mean'] = -1.0
        event['query_resource_record_type'] = []
        event['ans_resource_record_type'] = []
        event['query_resource_record_class'] = []
        event['ans_resource_record_class'] = []
        def create_ngrams(s, n):
            return [s[i:i+n] for i in range(len(s)-n+1)]
        event['uni_gram_domain_name'] = create_ngrams(query_name, 1)
        event['bi_gram_domain_name'] = create_ngrams(query_name, 2)
        event['tri_gram_domain_name'] = create_ngrams(query_name, 3)
        event['character_distribution'] = dict(Counter(query_name))
        vowels = set('aeiouAEIOU')
        consonants = set('bcdfghjklmnpqrstvwxyzBCDFGHJKLMNPQRSTVWXYZ')
        vowel_count = sum(1 for c in query_name if c in vowels)
        consonant_count = sum(1 for c in query_name if c in consonants)
        event['vowels_consonant_ratio'] = vowel_count / consonant_count if consonant_count > 0 else 0.0

        label, probs = predict_dns_event(event)
        print(f"[{event_type}]: {query_name}: {label}")
        print()

if __name__ == '__main__':
    print("Starting real-time DNS interceptor...")
    sniff(filter="udp port 53", prn=process_packet, store=0)

Starting real-time DNS interceptor...
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 44ms/step
[Query]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step
[Query]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 49ms/step
[Query]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 24ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step
[Response]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step
[Response]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 47ms/step
[Response]: api.teleparty.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 43ms/step
[Query]: www.bing.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 40ms/step
[Query]: www.bing.com: Benign

1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/ste ━━━━━━━━━━━━━━━━━━━━ 0s 41ms/step
[Query]: www.bing.com: Benign

1/1 ━