In [1]:
import sys
# Make the parent directory importable. The notebook runs one level below the
# project root (all data/model paths below start with '../'); presumably this
# is so project-local modules referenced by the pickled preprocessor or later
# cells can be found — TODO confirm.
sys.path.append('../')

import numpy as np
import joblib
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
import json
import time

rule = "=" * 70
print(rule)
print("MULTI-CLASS IDS - MODEL TRAINING")
print(rule)

# Load the pre-scaled train/test feature matrices and their labels, together
# with the fitted preprocessor bundle (label encoder + class names + feature
# column names).
print("\nLoading preprocessed data...")
X_train = np.load('../data/processed_multiclass/X_train_scaled.npy')
X_test = np.load('../data/processed_multiclass/X_test_scaled.npy')
y_train = np.load('../data/processed_multiclass/y_train.npy')
y_test = np.load('../data/processed_multiclass/y_test.npy')

preprocessor = joblib.load('../models/preprocessor_multiclass.pkl')
label_encoder, classes = preprocessor['label_encoder'], preprocessor['classes']

for summary_line in (f" Training set: {X_train.shape}",
                     f" Test set: {X_test.shape}",
                     f" Number of classes: {len(classes)}"):
    print(summary_line)

# Per-class sample counts. Labels are matched against each class's position in
# `classes`, i.e. this assumes y_train holds label-encoder indices.
print("\nClass distribution in training set:")
n_train = len(y_train)
for label_idx, class_name in enumerate(classes):
    n_samples = np.count_nonzero(y_train == label_idx)
    share = (n_samples / n_train) * 100
    print(f"  {label_idx:2d}. {class_name:<35} {n_samples:>8,} ({share:>5.2f}%)")

print("\n" + "="*70)
print("Training Random Forest Classifier...")
print("="*70)

# Build the model first, then print its configuration from get_params(). The
# summary then always reflects the hyperparameters actually used, instead of
# hard-coded strings that can silently drift from the constructor arguments.
model = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,
    random_state=42,
    n_jobs=-1,
    verbose=1,
    # Weight classes inversely to their frequency; the class distribution
    # printed above is heavily imbalanced.
    class_weight='balanced'
)

rf_params = model.get_params()
print("Configuration:")
print(f"  - {rf_params['n_estimators']} trees")
print(f"  - Max depth: {rf_params['max_depth']}")
if rf_params['n_jobs'] == -1:
    print("  - Using all CPU cores")
else:
    print(f"  - n_jobs: {rf_params['n_jobs']}")
print(f"  - Class weight: {rf_params['class_weight']}")

# perf_counter() is monotonic and high-resolution, unlike time.time(), so it
# is the appropriate clock for measuring elapsed durations.
start_time = time.perf_counter()
model.fit(X_train, y_train)
training_time = time.perf_counter() - start_time

print(f"\n Training completed in {training_time:.2f} seconds ({training_time/60:.2f} minutes)")

print("\nMaking predictions on test set...")
# Monotonic, high-resolution clock for elapsed time (time.time() can jump if
# the system clock is adjusted).
start_time = time.perf_counter()
y_pred = model.predict(X_test)
inference_time = time.perf_counter() - start_time

accuracy = accuracy_score(y_test, y_pred)

print(f"\n{'='*70}")
print("RESULTS")
print(f"{'='*70}")
print(f"Overall Accuracy: {accuracy*100:.2f}%")
print(f"Inference Time: {inference_time:.2f}s for {len(y_test):,} samples")
# Amortized cost from a single batched predict() call over the whole test set,
# not the latency of classifying one sample in isolation.
print(f"Per-sample: {(inference_time/len(y_test))*1000:.4f}ms")

print(f"\n{'='*70}")
print("Per-Class Performance:")
print(f"{'='*70}")
# Pin the label set to every encoded class (0 .. len(classes)-1). Without
# `labels=`, sklearn infers the labels from y_test and y_pred. A class missing
# from both (e.g. a very rare attack class with only a handful of samples)
# would then leave `target_names` with the wrong length and raise ValueError.
# zero_division=0 reports 0.0, without a warning, for classes never predicted.
class_labels = np.arange(len(classes))
report_dict = classification_report(
    y_test,
    y_pred,
    labels=class_labels,
    target_names=classes,
    output_dict=True,
    zero_division=0,
)

print(f"\n{'Class':<35} {'Precision':<12} {'Recall':<12} {'F1-Score':<12} {'Support':<10}")
print("-" * 85)
for cls in classes:
    if cls in report_dict:
        metrics = report_dict[cls]
        print(f"{cls:<35} {metrics['precision']:<12.4f} {metrics['recall']:<12.4f} {metrics['f1-score']:<12.4f} {int(metrics['support']):<10,}")

# Confusion matrix analysis. `labels=` fixes row/column i to class i. Without
# it, a class absent from both y_test and y_pred would shift every later index,
# and cm[i, i] would describe the wrong class.
cm = confusion_matrix(y_test, y_pred, labels=class_labels)
print(f"\n{'='*70}")
print("Attack Detection Performance:")
print(f"{'='*70}")

# Detection rate per class = correctly classified / actual samples of that
# class (i.e. per-class recall). Classes with no test samples are skipped.
for i, cls in enumerate(classes):
    total = (y_test == i).sum()
    if total > 0:
        correct = cm[i, i]
        accuracy_cls = (correct / total) * 100
        print(f"{cls:<35} {correct:>6,}/{total:>6,} ({accuracy_cls:>6.2f}%)")

print(f"\n{'='*70}")
print("Saving model...")
print(f"{'='*70}")
joblib.dump(model, '../models/random_forest_multiclass.pkl')
print(" Saved model to: models/random_forest_multiclass.pkl")

results = {
    'accuracy': float(accuracy),
    'training_time_seconds': float(training_time),
    'inference_time_seconds': float(inference_time),
    'per_sample_ms': float((inference_time/len(y_test))*1000),
    # Convert to plain Python strings: `classes` may be a NumPy array
    # (presumably a LabelEncoder's classes_ — confirm), which json cannot
    # serialize directly.
    'classes': [str(c) for c in classes],
    # Reuse the report built for the per-class table rather than calling
    # classification_report a second time. This avoids redundant work and
    # guarantees the saved metrics match what was printed.
    'classification_report': report_dict,
}

with open('../models/multiclass_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
print("✓ Saved results to: models/multiclass_results.json")

# Rank input features by the forest's feature_importances_ scores and list the
# 15 highest. `feature_columns` maps each column position of the feature matrix
# to its name.
sep = '=' * 70
print(f"\n{sep}")
print("Top 15 Most Important Features for Attack Detection:")
print(sep)
feature_columns = preprocessor['feature_columns']
importances = model.feature_importances_
# Ascending argsort reversed -> descending order; keep the first 15 positions.
indices = importances.argsort()[::-1][:15]

for rank, idx in zip(range(1, len(indices) + 1), indices):
    feature_name = feature_columns[idx]
    print(f"{rank:2d}. {feature_name:<40} {importances[idx]:.4f}")

print(f"\n{sep}")
print("TRAINING COMPLETE!")
print(sep)

MULTI-CLASS IDS - MODEL TRAINING

Loading preprocessed data...
 Training set: (2016638, 78)
 Test set: (504160, 78)
 Number of classes: 15

Class distribution in training set:
   0. BENIGN                              1,676,045 (83.11%)
   1. Bot                                    1,558 ( 0.08%)
   2. DDoS                                 102,411 ( 5.08%)
   3. DoS GoldenEye                          8,229 ( 0.41%)
   4. DoS Hulk                             138,277 ( 6.86%)
   5. DoS Slowhttptest                       4,182 ( 0.21%)
   6. DoS slowloris                          4,308 ( 0.21%)
   7. FTP-Patator                            4,745 ( 0.24%)
   8. Heartbleed                                 9 ( 0.00%)
   9. Infiltration                              29 ( 0.00%)
  10. PortScan                              72,555 ( 3.60%)
  11. SSH-Patator                            2,575 ( 0.13%)
  12. Web Attack - Brute Force               1,176 ( 0.06%)
  13. Web Attack - Sql Injection           

[Parallel(n_jobs=-1)]: Using backend ThreadingBackend with 10 concurrent workers.
[Parallel(n_jobs=-1)]: Done  30 tasks      | elapsed:   25.5s
[Parallel(n_jobs=-1)]: Done 100 out of 100 | elapsed:  1.2min finished
[Parallel(n_jobs=10)]: Using backend ThreadingBackend with 10 concurrent workers.



 Training completed in 75.18 seconds (1.25 minutes)

Making predictions on test set...


[Parallel(n_jobs=10)]: Done  30 tasks      | elapsed:    0.3s
[Parallel(n_jobs=10)]: Done 100 out of 100 | elapsed:    0.7s finished



RESULTS
Overall Accuracy: 99.47%
Inference Time: 0.76s for 504,160 samples
Per-sample: 0.0015ms

Per-Class Performance:

Class                               Precision    Recall       F1-Score     Support   
-------------------------------------------------------------------------------------
BENIGN                              0.9999       0.9942       0.9970       419,012   
Bot                                 0.1598       0.9821       0.2748       390       
DDoS                                1.0000       0.9998       0.9999       25,603    
DoS GoldenEye                       0.9927       0.9942       0.9934       2,057     
DoS Hulk                            0.9937       0.9995       0.9966       34,569    
DoS Slowhttptest                    0.9914       0.9962       0.9938       1,046     
DoS slowloris                       0.9953       0.9879       0.9916       1,077     
FTP-Patator                         1.0000       0.9983       0.9992       1,186     
Heartbleed        