In [1]:
pip install pandas numpy scikit-learn lime shap matplotlib lightgbm xgboost ipykernel

Note: you may need to restart the kernel to use updated packages.


In [2]:
# Import required libraries
import pandas as pd
import numpy as np
import pickle
import lime
import lime.lime_tabular
import shap
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler, LabelEncoder
import os

In [3]:
def preprocess_network_features(df):
    """Strip identifier columns and integer-encode object-typed columns.

    The input frame is not modified; all work happens on a copy.

    Args:
        df: DataFrame of network traffic records.

    Returns:
        tuple: ``(df_processed, label_encoders)`` where ``df_processed`` is the
        cleaned copy and ``label_encoders`` maps each encoded column name to
        the ``LabelEncoder`` that was fitted on it.
    """
    # Identifier / metadata columns that are removed when present
    unwanted = ('FlowID', 'Timestamp', 'SourceIP', 'DestinationIP', 'Label')

    df_processed = df.copy()
    present = [name for name in unwanted if name in df_processed.columns]
    df_processed = df_processed.drop(present, axis=1)

    # Every object-typed column gets its own encoder, fitted on the
    # string representation of its values
    object_columns = [
        name for name in df_processed.columns
        if df_processed[name].dtype == 'object'
    ]
    label_encoders = {}
    for name in object_columns:
        encoder = LabelEncoder()
        df_processed[name] = encoder.fit_transform(df_processed[name].astype(str))
        label_encoders[name] = encoder

    return df_processed, label_encoders

In [4]:
def load_and_prepare_data(test_file='Dataset/test_data.csv', sample_size=None):
    """Load the test CSV, preprocess it and standardise the features.

    Args:
        test_file: Path of the CSV file to read.
        sample_size: When truthy and smaller than the number of rows, a random
            sample of that many rows (``random_state=42``) replaces the full
            data set. ``None`` or ``0`` keeps every row.

    Returns:
        tuple: ``(X_test_scaled, feature_names, X_test_original, y_test)``
            - ``X_test_scaled``: DataFrame of standardised features.
            - ``feature_names``: list of feature column names.
            - ``X_test_original``: copy of the preprocessed, unscaled features.
            - ``y_test``: the ``Label`` column, or ``None`` when the file has
              no such column.
        All returned frames/series share the same 0..n-1 row index.
    """
    print("Loading test data...")
    test_data = pd.read_csv(test_file)
    print(f"Loaded {len(test_data)} rows of test data")
    
    # Take a random sample if specified (used for SHAP analysis)
    if sample_size and len(test_data) > sample_size:
        # Reset the index so y_test / X_test_original stay aligned with
        # X_test_scaled, which is rebuilt below with a fresh 0..n-1 index.
        test_data = test_data.sample(n=sample_size, random_state=42).reset_index(drop=True)
        print(f"Using {sample_size} random samples for analysis")
    
    # Separate features and target
    if 'Label' in test_data.columns:
        y_test = test_data['Label']
        X_test = test_data.drop('Label', axis=1)
    else:
        X_test = test_data
        y_test = None
    
    # Preprocess features
    print("\nPreprocessing features...")
    X_test_processed, label_encoders = preprocess_network_features(X_test)
    
    # Store original data for reference (used in LIME)
    X_test_original = X_test_processed.copy()
    
    # Scale the features
    # NOTE(review): the scaler is fitted on the test data itself; presumably the
    # models were trained on data scaled the same way -- confirm against training.
    print("\nScaling features...")
    scaler = StandardScaler()
    X_test_scaled = scaler.fit_transform(X_test_processed)
    
    # Convert scaled data back to DataFrame to keep feature names
    X_test_scaled = pd.DataFrame(X_test_scaled, columns=X_test_processed.columns)
    
    return X_test_scaled, X_test_processed.columns.tolist(), X_test_original, y_test

In [5]:

# LIME Explanation Functions
def create_lime_explanation(model, X_test, feature_names, X_test_original, 
                          instance_idx=0, num_features=5, random_state=42):
    """Explain one prediction with LIME and draw it as a two-panel figure.

    The top panel shows the model's predicted probabilities for the two
    classes; the bottom panel shows the LIME weight of the most influential
    features for the explained instance.

    Args:
        model: Fitted classifier exposing ``predict_proba``.
        X_test: DataFrame of (scaled) features; used both as LIME background
            data and as the source of the instance to explain.
        feature_names: List of feature names, in column order.
        X_test_original: Unscaled features. Currently unused; kept so existing
            callers keep working.
        instance_idx: Positional index of the row of ``X_test`` to explain.
        num_features: Number of features to include in the explanation.
        random_state: Seed handed to the LIME explainer so that repeated runs
            produce the same explanation.

    Returns:
        tuple: ``(exp, fig)`` -- the LIME explanation object and the matplotlib
        figure. The caller is responsible for saving/closing the figure.
    """
    # Create the LIME explainer (seeded: LIME samples perturbations randomly)
    explainer = lime.lime_tabular.LimeTabularExplainer(
        training_data=X_test.values,
        feature_names=feature_names,
        class_names=['benign', 'malicious'],
        mode='classification',
        random_state=random_state
    )
    
    # Get the explanation
    instance = X_test.iloc[instance_idx].values
    exp = explainer.explain_instance(
        instance, 
        model.predict_proba,
        num_features=num_features
    )
    
    # Get prediction probabilities (index 0 -> benign, index 1 -> malicious)
    pred_proba = model.predict_proba(instance.reshape(1, -1))[0]
    
    # Create figure with custom layout: small probability panel above a
    # larger feature-importance panel
    fig = plt.figure(figsize=(15, 8))
    gs = plt.GridSpec(2, 1, height_ratios=[1, 2])
    
    # Prediction probability subplot (top); axes are attached to `fig`
    # explicitly rather than to whatever pyplot considers current
    ax_pred = fig.add_subplot(gs[0])
    ax_pred.set_title("Prediction probabilities", pad=20)
    ax_pred.barh(['malicious'], [pred_proba[1]], color='orange', height=0.3)
    ax_pred.barh(['benign'], [pred_proba[0]], color='blue', height=0.3)
    ax_pred.set_xlim(0, 1)
    ax_pred.grid(True)
    ax_pred.set_xlabel('Probability')
    
    # Feature importance subplot (bottom)
    ax_imp = fig.add_subplot(gs[1])
    ax_imp.set_title("Feature importance", pad=20)
    
    # Each entry of the explanation is a (feature description, weight) pair
    exp_list = exp.as_list()
    features = [x[0] for x in exp_list]
    scores = [x[1] for x in exp_list]
    
    # Positive weights are drawn in orange, the rest in blue
    colors = ['orange' if score > 0 else 'blue' for score in scores]
    y_pos = range(len(features))
    ax_imp.barh(y_pos, scores, color=colors)
    
    # Customize feature importance axis
    ax_imp.set_yticks(y_pos)
    ax_imp.set_yticklabels(features)
    ax_imp.set_xlabel('Impact on prediction')
    ax_imp.grid(True)
    
    # Adjust layout
    fig.tight_layout()
    
    return exp, fig

In [6]:

def create_lime_explanations():
    """Create and save a LIME explanation figure for every saved model.

    Loads the test data once, then for each model pickle under ``Models/``
    explains the first test instance and writes the figure to
    ``lime_charts/lime_explanation_<model>.png``. A missing model file or a
    failure while processing one model is reported and does not stop the
    remaining models from being processed.
    """
    import traceback  # only needed for error reporting below

    print("\n=== Generating LIME Explanations ===")
    
    # Load and prepare data (labels are not needed for the explanations)
    X_test_scaled, feature_names, X_test_original, _ = load_and_prepare_data()
    
    # List of model files
    model_files = ['Random_Forest.pkl', 'Decision_Tree.pkl', 'LightGBM.pkl', 'XGBoost.pkl']
    
    # Create explanations for each model
    for model_file in model_files:
        try:
            print(f"\nCreating LIME explanation for {model_file}")
            
            # Load model
            model_path = os.path.join('Models', model_file)
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                continue
            
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load model files that come from a trusted source.
            with open(model_path, 'rb') as f:
                model = pickle.load(f)
            
            # Create explanation
            exp, fig = create_lime_explanation(
                model=model,
                X_test=X_test_scaled,
                feature_names=feature_names,
                X_test_original=X_test_original,
                instance_idx=0,  # First instance
                num_features=5   # Show top 5 features
            )
            
            # Save the figure in lime_charts directory
            os.makedirs('lime_charts', exist_ok=True)
            model_stem = os.path.splitext(model_file)[0]
            output_file = os.path.join('lime_charts', f'lime_explanation_{model_stem}.png')
            fig.savefig(output_file)
            plt.close(fig)
            
            print(f"LIME explanation saved as {output_file}")
            
        except Exception as e:
            # Best effort: report the failure and move on to the next model
            print(f"Error processing {model_file}: {str(e)}")
            print(traceback.format_exc())

In [7]:
# SHAP Explanation Functions
def create_shap_explanation(model, X_test, feature_names, model_name):
    """Create SHAP global feature-importance plots (bar and beeswarm) for a model.

    SHAP values are computed for every row of ``X_test``. Two PNG files are
    written to ``shap_charts/Global Importance/``:

    - ``shap_bar_<model_name>.png``: mean absolute SHAP value per feature.
    - ``shap_beeswarm_<model_name>.png``: SHAP beeswarm plot.

    Args:
        model: Fitted model to explain.
        X_test: DataFrame of features; used as background data and as the
            instances to explain.
        feature_names: List of feature names, in column order.
        model_name: Name used in plot titles and output file names.
    """
    # Create SHAP explainer
    explainer = shap.Explainer(model, X_test, feature_names=feature_names)
    
    # Calculate SHAP values for all instances
    shap_values = explainer(X_test)
    
    # Some explainers return one set of values per class, i.e. a 3-D result
    # whose last axis has length 2 for binary classification; keep the
    # positive class. The dimensionality check prevents a 2-D result that
    # merely has two features from being sliced by mistake.
    if len(shap_values.shape) == 3 and shap_values.shape[-1] == 2:
        shap_values = shap_values[:, :, 1]
    
    # Calculate mean absolute SHAP values for global importance
    mean_abs_shap = np.abs(shap_values.values).mean(0)
    feature_importance = pd.DataFrame(list(zip(feature_names, mean_abs_shap)), 
                                    columns=['feature', 'importance'])
    feature_importance = feature_importance.sort_values('importance', ascending=True)
    
    output_dir = os.path.join('shap_charts', 'Global Importance')
    os.makedirs(output_dir, exist_ok=True)
    
    # The small font is applied through rc_context so the global matplotlib
    # settings are restored afterwards and later plots are not affected.
    with plt.rc_context({'font.size': 6}):
        # Create bar plot for global feature importance
        fig, ax = plt.subplots(figsize=(12, 18))
        positions = range(len(feature_importance))
        ax.barh(positions, feature_importance['importance'])
        
        # Customize the plot
        ax.set_title(f"SHAP Global Feature Importance - {model_name}", fontsize=8, pad=20)
        ax.set_xlabel('Mean |SHAP value|', fontsize=7)
        
        # Set y-tick positions and labels
        ax.set_yticks(positions)
        ax.set_yticklabels(feature_importance['feature'], rotation=0, ha='right', fontsize=6)
        
        # Wide left margin leaves room for long feature names
        fig.subplots_adjust(left=0.5, right=0.95, top=0.95, bottom=0.05)
        ax.tick_params(axis='y', pad=5)
        
        # Save bar plot
        bar_plot_file = os.path.join(output_dir, f'shap_bar_{model_name}.png')
        fig.savefig(bar_plot_file, bbox_inches='tight', dpi=300)
        plt.close(fig)
        print(f"Global importance bar plot saved as {bar_plot_file}")
        
        # Create and save beeswarm plot (SHAP draws on the current figure)
        plt.figure(figsize=(12, 8))
        shap.plots.beeswarm(shap_values, show=False)
        plt.title(f"SHAP Global Feature Importance - {model_name} (Beeswarm Plot)")
        plt.tight_layout()
        beeswarm_plot_file = os.path.join(output_dir, f'shap_beeswarm_{model_name}.png')
        plt.savefig(beeswarm_plot_file, bbox_inches='tight', dpi=300)
        plt.close()
        print(f"Beeswarm plot saved as {beeswarm_plot_file}")

In [8]:
def create_shap_explanations():
    """Create and save SHAP global-importance plots for every saved model.

    Loads a 1000-row random sample of the test data once, then for each model
    pickle under ``Models/`` calls ``create_shap_explanation``. A missing model
    file or a failure while processing one model is reported and does not stop
    the remaining models from being processed.
    """
    import traceback  # only needed for error reporting below

    print("\n=== Generating SHAP Explanations ===")
    
    # Load and prepare data (with sampling for SHAP); the unscaled features
    # and the labels are not needed here
    X_test_scaled, feature_names, _, _ = load_and_prepare_data(sample_size=1000)
    
    # Model file -> display name used in plot titles and output file names
    model_files = {
        'Random_Forest.pkl': 'Random_Forest',
        'Decision_Tree.pkl': 'Decision_Tree',
        'LightGBM.pkl': 'LightGBM',
        'XGBoost.pkl': 'XGBoost'
    }
    
    # Create explanations for each model
    for model_file, model_name in model_files.items():
        try:
            print(f"\nCreating SHAP explanation for {model_name}...")
            
            # Load model
            model_path = os.path.join('Models', model_file)
            if not os.path.exists(model_path):
                print(f"Model file not found: {model_path}")
                continue
            
            # SECURITY: pickle.load executes arbitrary code from the file;
            # only load model files that come from a trusted source.
            with open(model_path, 'rb') as f:
                model = pickle.load(f)
            
            # Create explanation
            create_shap_explanation(
                model=model,
                X_test=X_test_scaled,
                feature_names=feature_names,
                model_name=model_name
            )
            
        except Exception as e:
            # Best effort: report the failure and move on to the next model
            print(f"Error processing {model_file}: {str(e)}")
            print(traceback.format_exc())

In [9]:
if __name__ == "__main__":
    # Run both explanation pipelines in order: LIME first, then SHAP
    for run_pipeline in (create_lime_explanations, create_shap_explanations):
        run_pipeline()



=== Generating LIME Explanations ===
Loading test data...
Loaded 639695 rows of test data

Preprocessing features...

Scaling features...

Creating LIME explanation for Random_Forest.pkl
LIME explanation saved as lime_charts/lime_explanation_Random_Forest.png

Creating LIME explanation for Decision_Tree.pkl
LIME explanation saved as lime_charts/lime_explanation_Decision_Tree.png

Creating LIME explanation for LightGBM.pkl




LIME explanation saved as lime_charts/lime_explanation_LightGBM.png

Creating LIME explanation for XGBoost.pkl
LIME explanation saved as lime_charts/lime_explanation_XGBoost.png

=== Generating SHAP Explanations ===
Loading test data...
Loaded 639695 rows of test data
Using 1000 random samples for analysis

Preprocessing features...

Scaling features...

Creating SHAP explanation for Random_Forest...




Global importance bar plot saved as shap_charts/Global Importance/shap_bar_Random_Forest.png
Beeswarm plot saved as shap_charts/Global Importance/shap_beeswarm_Random_Forest.png

Creating SHAP explanation for Decision_Tree...
Global importance bar plot saved as shap_charts/Global Importance/shap_bar_Decision_Tree.png
Beeswarm plot saved as shap_charts/Global Importance/shap_beeswarm_Decision_Tree.png

Creating SHAP explanation for LightGBM...
Global importance bar plot saved as shap_charts/Global Importance/shap_bar_LightGBM.png
Beeswarm plot saved as shap_charts/Global Importance/shap_beeswarm_LightGBM.png

Creating SHAP explanation for XGBoost...
Global importance bar plot saved as shap_charts/Global Importance/shap_bar_XGBoost.png
Beeswarm plot saved as shap_charts/Global Importance/shap_beeswarm_XGBoost.png
