# 3. Model Evaluation and Export

**Project:** IoT Network Attack Detection using Federated Learning  
**Author:** Nguyen Duc Thang

---

## 📋 Objectives

1. Load trained global model
2. Generate predictions on test set
3. Calculate comprehensive metrics:
   - Overall accuracy
   - Per-class Precision, Recall, F1-Score
   - Confusion Matrix
4. Create visualizations for thesis report
5. Export all metrics and artifacts

---

## 🎯 Expected Outputs

- `../Output/metrics/confusion_matrix.png`
- `../Output/metrics/accuracy_plot.png`
- `../Output/metrics/f1_scores_per_class.png`
- `../Output/metrics/metrics_report.json`
- `../Output/metrics/classification_report.txt`


## 1. Setup and Imports


In [None]:
# Standard libraries
import os
import sys
import numpy as np
import pandas as pd
import json
import pickle
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, classification_report, 
    accuracy_score, precision_recall_fscore_support
)

# TensorFlow/Keras
import tensorflow as tf
from tensorflow import keras

# Import our utility modules
from utils import data_utils, model_utils

# Plot styling: seaborn grid theme, figures shown at 100 dpi on screen and
# written at 300 dpi so saved images are high resolution for the thesis.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.dpi': 100,
    'savefig.dpi': 300,
})

print("‚úÖ All imports successful!")

## 2. Load Trained Model and Test Data


In [None]:
# ============================================================================
# LOAD TRAINED MODEL (Framework-Agnostic)
# ============================================================================
# Reads the training config to decide which framework produced the global
# model, loads that model into `model`, then loads the test split into
# `X_test` / `y_test`.

import yaml
import json

# Load config to determine framework
config_path = 'configs/training_config.yaml'
with open(config_path, 'r') as f:
    # safe_load returns None for an empty file; fall back to an empty mapping
    # so the .get() below applies its default instead of raising.
    config = yaml.safe_load(f) or {}

# Normalise case/whitespace so a value such as "PyTorch" selects the PyTorch
# branch instead of silently falling through to the TensorFlow loader.
framework = str(config.get('framework') or 'tensorflow').strip().lower()

print(f"üîß Framework: {framework.upper()}")
print(f"{'='*80}\n")

if framework == 'pytorch':
    # ========== LOAD PYTORCH MODEL ==========
    import torch
    from utils.model_utils_pytorch import create_tabtransformer_from_config

    # Load feature config (also used later when generating predictions)
    feature_config_path = '../Output/models/feature_config.json'
    with open(feature_config_path, 'r') as f:
        feature_config = json.load(f)

    config['features'] = feature_config

    print("üìÇ Loading PyTorch TabTransformer model...")

    # Create model architecture
    model = create_tabtransformer_from_config(config)

    # Load trained weights onto the CPU first; the model is moved to the
    # selected device below.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (consider weights_only=True where the installed
    # PyTorch version supports it).
    model_path = '../Output/models/global_model.pth'
    model.load_state_dict(torch.load(model_path, map_location='cpu'))

    # Set to evaluation mode
    model.eval()

    # Set device
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    print(f"   ‚úì Model loaded from: {model_path}")
    print(f"   ‚úì Device: {device}")

else:
    # ========== LOAD TENSORFLOW MODEL (EXISTING) ==========
    from tensorflow import keras

    print("üìÇ Loading TensorFlow DNN model...")

    model_path = '../Output/models/global_model.h5'
    model = keras.models.load_model(model_path)

    print(f"   ‚úì Model loaded from: {model_path}")

print(f"\n{'='*80}")

# Load test data
data_dir = '../Output/data'
test_data = data_utils.load_client_data(data_dir, 'test')
X_test = test_data['X']
y_test = test_data['y']

print(f"\n‚úÖ Data loaded:")
print(f"   Test samples: {len(X_test):,}")
print(f"   Features: {X_test.shape[1]}")
print(f"   Classes: {len(np.unique(y_test))}")

## 3. Load Label Mapping


In [None]:
# Read the class-id -> class-name table saved alongside the model
labels_path = '../Output/models/labels.json'
with open(labels_path, 'r') as f:
    raw_mapping = json.load(f)

# JSON object keys are always strings; rebuild the dict with integer ids
label_mapping = dict(
    (int(class_id), class_name) for class_id, class_name in raw_mapping.items()
)

total_classes = len(label_mapping)
print(f"üìã Label mapping loaded ({total_classes} classes):")

# Preview at most the first ten class ids
preview_count = 10 if total_classes > 10 else total_classes
for i in range(preview_count):
    print(f"   {i}: {label_mapping[i]}")
if total_classes > 10:
    print(f"   ... and {total_classes - 10} more")

## 4. Generate Predictions


In [None]:
# ============================================================================
# GENERATE PREDICTIONS (Framework-Agnostic)
# ============================================================================
# Produces `y_pred`, a NumPy array holding the predicted class index for each
# test row, using whichever framework the model was loaded with.

print("üîÆ Generating predictions on test set...\n")

if framework == 'pytorch':
    # ========== PYTORCH PREDICTIONS ==========
    import torch
    from utils.fl_utils_pytorch import split_features

    # Rows scored per forward pass. Scoring in batches bounds device memory;
    # pushing the entire test set through the model at once can exhaust it.
    batch_size = 4096

    # Number of categorical features
    num_categorical = feature_config['num_categorical']

    # Cast to float32 on the CPU (the dtype torch.FloatTensor would produce)
    # rather than copying the whole array to the device and straight back.
    cat_features, num_features = split_features(
        np.asarray(X_test, dtype=np.float32),
        num_categorical,
        feature_config['categorical_cardinalities']
    )

    # Keep the full tensors on the CPU; only the current batch is moved to
    # the device inside the loop.
    cat_features = torch.LongTensor(cat_features)
    num_features = torch.FloatTensor(num_features)

    batch_predictions = []
    with torch.no_grad():
        for start in range(0, len(cat_features), batch_size):
            stop = start + batch_size
            logits = model(
                cat_features[start:stop].to(device),
                num_features[start:stop].to(device)
            )
            batch_predictions.append(logits.argmax(dim=1).cpu().numpy())

    # An empty test set yields no batches; fall back to an empty array
    if batch_predictions:
        y_pred = np.concatenate(batch_predictions)
    else:
        y_pred = np.empty(0, dtype=np.int64)

    print(f"   ‚úì Generated {len(y_pred):,} predictions using PyTorch")

else:
    # ========== TENSORFLOW PREDICTIONS (EXISTING) ==========
    y_pred_proba = model.predict(X_test, verbose=1)
    y_pred = np.argmax(y_pred_proba, axis=1)

    print(f"   ‚úì Generated {len(y_pred):,} predictions using TensorFlow")

print(f"\n‚úÖ Predictions complete!")
print(f"   Test samples: {len(y_pred):,}")
print(f"   Unique predictions: {len(np.unique(y_pred))}")

## 5. Calculate Overall Metrics


In [None]:
# Calculate overall accuracy
overall_accuracy = accuracy_score(y_test, y_pred)

# Class ids from the label mapping. Passing them explicitly keeps the
# per-class arrays the same length and order as the mapping even when a class
# is absent from both y_test and y_pred; without it the arrays only cover the
# labels that occur, and position i no longer corresponds to class id i.
# Assumes label_mapping keys are the contiguous ids 0..n-1 (the cells that
# consume these arrays index label_mapping by position) - TODO confirm.
class_ids = sorted(label_mapping)

# Calculate per-class metrics (a class with no samples and no predictions
# gets 0 for every score because zero_division=0)
precision, recall, f1, support = precision_recall_fscore_support(
    y_test, y_pred, labels=class_ids, average=None, zero_division=0
)

# Calculate macro and weighted averages over the labels that actually occur
precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
    y_test, y_pred, average='macro', zero_division=0
)

precision_weighted, recall_weighted, f1_weighted, _ = precision_recall_fscore_support(
    y_test, y_pred, average='weighted', zero_division=0
)

print("="*80)
print("OVERALL METRICS")
print("="*80)
print(f"\nüìä Overall Accuracy: {overall_accuracy:.4f} ({overall_accuracy*100:.2f}%)")
print(f"\nüìà Macro Averages (unweighted):")
print(f"   Precision: {precision_macro:.4f}")
print(f"   Recall: {recall_macro:.4f}")
print(f"   F1-Score: {f1_macro:.4f}")
print(f"\nüìà Weighted Averages (by support):")
print(f"   Precision: {precision_weighted:.4f}")
print(f"   Recall: {recall_weighted:.4f}")
print(f"   F1-Score: {f1_weighted:.4f}")
print("="*80)

# Check if target met (the gap is reported in percentage points)
if overall_accuracy >= 0.95:
    print(f"\n‚úÖ SUCCESS: Target accuracy (>95%) achieved!")
else:
    print(f"\n‚ö†Ô∏è  Target accuracy (>95%) not achieved.")
    print(f"   Gap: {(0.95 - overall_accuracy)*100:.2f}%")

## 6. Per-Class Metrics


In [None]:
# Resolve a display name for every position in the per-class arrays
class_labels = list(map(label_mapping.__getitem__, range(len(precision))))

# Assemble the per-class table and rank it from best to worst F1-Score
metrics_df = pd.DataFrame({
    'Class': class_labels,
    'Precision': precision,
    'Recall': recall,
    'F1-Score': f1,
    'Support': support
}).sort_values(by='F1-Score', ascending=False)

print("\nüìä Per-Class Metrics (sorted by F1-Score):")
display(metrics_df)

# Pull out the classes whose F1-Score falls short of the threshold
threshold = 0.85
low_f1_classes = metrics_df.loc[metrics_df['F1-Score'] < threshold]

if low_f1_classes.empty:
    print(f"\n‚úÖ All classes have F1-Score >= {threshold}!")
else:
    print(f"\n‚ö†Ô∏è  Classes with F1-Score < {threshold}:")
    display(low_f1_classes)

## 7. Confusion Matrix


In [None]:
# Calculate confusion matrix.
# The class ids are passed explicitly so the matrix always has one row and
# column per entry in label_mapping, in id order. Without `labels`, a class
# missing from both y_test and y_pred is dropped and every later row/column
# shifts, so position i would stop corresponding to class id i.
cm = confusion_matrix(y_test, y_pred, labels=sorted(label_mapping))

print(f"üìä Confusion Matrix shape: {cm.shape}")
print(f"   Diagonal sum (correct predictions): {np.trace(cm):,}")
print(f"   Off-diagonal sum (misclassifications): {cm.sum() - np.trace(cm):,}")

### 7.1 Visualize Confusion Matrix


In [None]:
# Create confusion matrix heatmap
plt.figure(figsize=(20, 18))

# Row-normalise so each row shows the fraction of that true class assigned to
# every predicted class. Rows with no true samples are left as zeros instead
# of turning into NaN through a 0/0 division.
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0
)

# Tick labels: class name for each row/column position of the matrix
tick_labels = [label_mapping[i] for i in range(len(cm))]

# Create heatmap
sns.heatmap(cm_normalized, annot=False, fmt='.2f', cmap='Blues',
            xticklabels=tick_labels,
            yticklabels=tick_labels,
            cbar_kws={'label': 'Normalized Count'})

# The class count in the title follows the matrix size rather than a fixed 34
plt.title(f'Confusion Matrix (Normalized)\nIoT Attack Detection - {len(cm)} Classes',
         fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=14, fontweight='bold')
plt.xticks(rotation=90, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()

# Save figure (create the metrics directory first if it does not exist yet)
output_path = '../Output/metrics/confusion_matrix.png'
os.makedirs(os.path.dirname(output_path), exist_ok=True)
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"üíæ Confusion matrix saved to: {output_path}")

plt.show()

## 8. Training History Visualization


In [None]:
# Read the per-round history written during federated training
history_path = '../Output/metrics/training_history.json'
with open(history_path, 'r') as f:
    training_history = json.load(f)

# Pull out the three series that get plotted
history = training_history['history']
rounds, accuracy, loss = (history[key] for key in ('round', 'accuracy', 'loss'))

# One figure, two panels: accuracy on the left, loss on the right
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
acc_ax, loss_ax = axes

# Accuracy curve with the 95% target line and a shaded area under the curve
acc_ax.plot(rounds, accuracy, marker='o', linewidth=2.5, markersize=7, color='#2E86AB')
acc_ax.axhline(y=0.95, color='red', linestyle='--', linewidth=2, label='Target (95%)', alpha=0.7)
acc_ax.fill_between(rounds, 0, accuracy, alpha=0.2, color='#2E86AB')

# Loss curve with a shaded area under the curve
loss_ax.plot(rounds, loss, marker='o', linewidth=2.5, markersize=7, color='#F18F01')
loss_ax.fill_between(rounds, 0, loss, alpha=0.2, color='#F18F01')

# Titles, axis labels and grid are styled the same way on both panels
panel_text = (
    (acc_ax, 'Global Model Accuracy vs Round', 'Accuracy'),
    (loss_ax, 'Global Model Loss vs Round', 'Loss'),
)
for panel, title_text, y_label in panel_text:
    panel.set_title(title_text, fontsize=14, fontweight='bold')
    panel.set_xlabel('Round', fontsize=12, fontweight='bold')
    panel.set_ylabel(y_label, fontsize=12, fontweight='bold')
    panel.grid(True, alpha=0.3, linestyle='--')

# Accuracy-only extras: fixed y-range and a legend for the target line
acc_ax.set_ylim([0, 1.05])
acc_ax.legend(fontsize=11)

plt.tight_layout()

# Write the figure to disk, then show it inline
output_path = '../Output/metrics/accuracy_plot.png'
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"üíæ Training curves saved to: {output_path}")

plt.show()

## 9. Per-Class F1-Score Visualization


In [None]:
# Bar chart of the F1-Score for each class
plt.figure(figsize=(18, 8))

# Bar positions, tick labels, and a green/red colour per bar depending on
# whether the class reaches the 0.85 threshold
bar_positions = range(len(f1))
class_names = [label_mapping[i] for i in bar_positions]
meets_threshold = f1 >= 0.85
colors = np.where(meets_threshold, '#27AE60', '#E74C3C').tolist()

# Draw the bars
bars = plt.bar(bar_positions, f1, color=colors, edgecolor='black', linewidth=0.5)

# Dashed reference line at the threshold
plt.axhline(y=0.85, color='red', linestyle='--', linewidth=2,
           label='Threshold (0.85)', alpha=0.7)

# Titles, axis labels, ticks and legend
plt.title('F1-Score per Attack Class\n(Green: ‚â•0.85 | Red: <0.85)',
         fontsize=16, fontweight='bold', pad=20)
plt.xlabel('Attack Class', fontsize=14, fontweight='bold')
plt.ylabel('F1-Score', fontsize=14, fontweight='bold')
plt.xticks(bar_positions, class_names, rotation=90, ha='right')
plt.ylim([0, 1.05])
plt.grid(axis='y', alpha=0.3, linestyle='--')
plt.legend(fontsize=12)
plt.tight_layout()

# Write the figure to disk, then show it inline
output_path = '../Output/metrics/f1_scores_per_class.png'
plt.savefig(output_path, dpi=300, bbox_inches='tight')
print(f"üíæ F1-Score chart saved to: {output_path}")

plt.show()

# Count how many classes sit on each side of the threshold
num_above_threshold = meets_threshold.sum()
num_below_threshold = (f1 < 0.85).sum()
total_classes_plotted = len(f1)

print(f"\nüìä F1-Score Summary:")
print(f"   Classes with F1 ‚â• 0.85: {num_above_threshold}/{total_classes_plotted} ({num_above_threshold/total_classes_plotted*100:.1f}%)")
print(f"   Classes with F1 < 0.85: {num_below_threshold}/{total_classes_plotted} ({num_below_threshold/total_classes_plotted*100:.1f}%)")

## 10. Export Comprehensive Metrics Report


In [None]:
# Per-class section: one entry per class name, built from the parallel
# precision / recall / F1 / support arrays. Values are cast to plain Python
# numbers so they can be JSON-serialised.
per_class_metrics = {
    label_mapping[class_id]: {
        'class_id': int(class_id),
        'precision': float(class_precision),
        'recall': float(class_recall),
        'f1_score': float(class_f1),
        'support': int(class_support)
    }
    for class_id, (class_precision, class_recall, class_f1, class_support)
    in enumerate(zip(precision, recall, f1, support))
}

# Headline numbers
overall_metrics = {
    'accuracy': float(overall_accuracy),
    'precision_macro': float(precision_macro),
    'recall_macro': float(recall_macro),
    'f1_macro': float(f1_macro),
    'precision_weighted': float(precision_weighted),
    'recall_weighted': float(recall_weighted),
    'f1_weighted': float(f1_weighted)
}

# Counts and pass/fail flags
summary = {
    'total_test_samples': int(len(y_test)),
    'num_classes': int(len(label_mapping)),
    'classes_above_f1_threshold': int(num_above_threshold),
    'classes_below_f1_threshold': int(num_below_threshold),
    'target_accuracy_met': bool(overall_accuracy >= 0.95),
    'all_classes_above_threshold': bool(num_below_threshold == 0)
}

# Assemble the full report (section order is preserved in the JSON output)
metrics_report = {
    'overall_metrics': overall_metrics,
    'per_class_metrics': per_class_metrics,
    'summary': summary,
    'confusion_matrix': cm.tolist()
}

# Write the report as indented JSON
report_path = '../Output/metrics/metrics_report.json'
with open(report_path, 'w') as f:
    json.dump(metrics_report, f, indent=2)

print(f"üíæ Comprehensive metrics report saved to: {report_path}")
print(f"\n‚úÖ Report includes:")
print(f"   - Overall metrics (accuracy, precision, recall, F1)")
print(f"   - Per-class metrics for all {len(label_mapping)} classes")
print(f"   - Confusion matrix")
print(f"   - Summary statistics")

## 11. Generate Classification Report


In [None]:
# Generate sklearn classification report.
# `labels` is passed together with `target_names` so each name is bound to its
# class id. Without it sklearn infers the labels from the data, and when a
# class is missing from y_test/y_pred the name list no longer matches the
# inferred labels (sklearn then raises a ValueError or mislabels rows).
report_labels = list(range(len(label_mapping)))
target_names = [label_mapping[i] for i in report_labels]
report = classification_report(
    y_test, y_pred,
    labels=report_labels,
    target_names=target_names,
    zero_division=0
)

print("="*80)
print("CLASSIFICATION REPORT")
print("="*80)
print(report)

# Save to text file
report_txt_path = '../Output/metrics/classification_report.txt'
with open(report_txt_path, 'w') as f:
    f.write("CLASSIFICATION REPORT\n")
    f.write("="*80 + "\n")
    f.write(report)

print(f"\nüíæ Classification report saved to: {report_txt_path}")

## 12. Final Summary


In [None]:
# Final recap: headline metrics, target checks, and the files produced by
# this notebook plus the artifacts handed to the web app.
print("="*80)
print("MODEL EVALUATION SUMMARY")
print("="*80)

print(f"\nüìä Performance Metrics:")
print(f"   Overall Accuracy: {overall_accuracy*100:.2f}%")
print(f"   Macro F1-Score: {f1_macro:.4f}")
print(f"   Weighted F1-Score: {f1_weighted:.4f}")

print(f"\nüéØ Target Achievement:")
if overall_accuracy >= 0.95:
    print(f"   ‚úÖ Accuracy target (>95%): ACHIEVED")
else:
    print(f"   ‚ùå Accuracy target (>95%): NOT ACHIEVED (Gap: {(0.95-overall_accuracy)*100:.2f}%)")

if num_below_threshold == 0:
    print(f"   ‚úÖ F1-Score target (>0.85 for all classes): ACHIEVED")
else:
    print(f"   ‚ö†Ô∏è  F1-Score target: {num_below_threshold} classes below 0.85")

print(f"\nüìÅ Generated Files:")
output_files = [
    '../Output/metrics/confusion_matrix.png',
    '../Output/metrics/accuracy_plot.png',
    '../Output/metrics/f1_scores_per_class.png',
    '../Output/metrics/metrics_report.json',
    '../Output/metrics/classification_report.txt'
]

# Only files that exist on disk are listed; sizes are shown in KB
for file_path in output_files:
    if os.path.exists(file_path):
        file_size = os.path.getsize(file_path) / 1024  # KB
        print(f"   ‚úì {os.path.basename(file_path)} ({file_size:.2f} KB)")

print(f"\nüíæ All Deliverables for Web App:")
# The model entry is the file that was actually loaded and evaluated
# (`model_path` is the .pth checkpoint for PyTorch, the .h5 file for
# TensorFlow) rather than a hard-coded .h5 path.
deliverables = [
    model_path,
    '../Output/models/scaler.pkl',
    '../Output/models/label_encoder.pkl',
    '../Output/models/labels.json'
]

for file_path in deliverables:
    if os.path.exists(file_path):
        file_size = os.path.getsize(file_path) / 1024  # KB
        if file_path == model_path:
            file_size = file_size / 1024  # MB for model
            print(f"   ‚úì {os.path.basename(file_path)} ({file_size:.2f} MB)")
        else:
            print(f"   ‚úì {os.path.basename(file_path)} ({file_size:.2f} KB)")

print(f"\n‚úÖ MODEL EVALUATION COMPLETE!")
print(f"\nüìù Next steps:")
print(f"   1. Review all visualizations and metrics")
print(f"   2. Include plots in thesis report")
print(f"   3. Use deliverables for Web App integration")
print("="*80)