# Model Evaluation for CoT-KG Network Intrusion Detection

# %%
import sys
sys.path.append('..')

from src.data_processing.preprocess import load_and_preprocess_data
from src.models.graph_sage_model import GraphSAGE, evaluate_graph_sage
from src.models.hybrid_model import HybridModel, evaluate_hybrid_model
from src.evaluation.metrics import evaluate_model
from src.visualization.kg_visualizer import visualize_feature_importance

import torch
import pandas as pd
import matplotlib.pyplot as plt

## Load Data and Models

# %%
# Load the preprocessed CICIDS2017 dataset and both trained checkpoints.
# The paths are relative to the current working directory, which is presumably
# this notebook's folder (the same assumption as sys.path.append('..')).
data = load_and_preprocess_data('../data/processed/CICIDS2017_processed.csv')

# Loaded left to right: GraphSAGE first, then the hybrid model.
graph_sage_model, hybrid_model = (
    GraphSAGE.load('../models/graph_sage_model.pt'),
    HybridModel.load('../models/hybrid_model.pt'),
)

## Evaluate Models

# %%
# Inference only. Switch any torch modules to eval mode so dropout/batch-norm
# layers do not make the scores nondeterministic; this is a no-op if the
# evaluate_* helpers already do it. Objects that are not torch modules are
# left untouched.
for _model in (graph_sage_model, hybrid_model):
    if isinstance(_model, torch.nn.Module):
        _model.eval()

# Disable autograd while scoring: no gradients are needed here, and skipping
# graph construction saves memory on the full dataset.
with torch.no_grad():
    # Each evaluator's result is unpacked as (accuracy, predictions). The
    # predictions are later compared against data.y[data.test_mask], so they
    # presumably cover only the test split (TODO confirm).
    graph_sage_acc, graph_sage_pred = evaluate_graph_sage(graph_sage_model, data)
    hybrid_acc, hybrid_pred = evaluate_hybrid_model(hybrid_model, data)

print(f'GraphSAGE Accuracy: {graph_sage_acc:.4f}')
print(f'Hybrid Model Accuracy: {hybrid_acc:.4f}')

## Detailed Evaluation Metrics

# %%
# Score both models against the same ground-truth test labels. evaluate_model
# yields a (confusion matrix, classification report) pair, with class names
# taken from data.classes.
graph_sage_cm, graph_sage_report = evaluate_model(
    data.y[data.test_mask],
    graph_sage_pred,
    class_names=data.classes,
)
hybrid_cm, hybrid_report = evaluate_model(
    data.y[data.test_mask],
    hybrid_pred,
    class_names=data.classes,
)

# Print each header followed by its report, one item per line.
print("GraphSAGE Classification Report:", graph_sage_report, sep="\n")
print("\nHybrid Model Classification Report:", hybrid_report, sep="\n")

## Visualize Feature Importance

# %%
from src.explainability.integrated_gradients import ExplainabilityAnalyzer

# Index of the class whose feature attributions are computed. Change this to
# explain a different class. It presumably indexes into data.classes, the list
# passed as class_names to the classification reports (TODO confirm).
target_class = 0
print(f"Explaining class {target_class}: {data.classes[target_class]}")

# Attribute the hybrid model's predictions for target_class to input
# features, then plot the result.
explainer = ExplainabilityAnalyzer(hybrid_model)
feature_importance = explainer.explain(data, target_class=target_class)
visualize_feature_importance(feature_importance)