In [None]:
import torch.nn as nn
import torch
import pickle
from sklearn.metrics import accuracy_score, f1_score, cohen_kappa_score, confusion_matrix, precision_score, recall_score
import pandas as pd
import torch
from tqdm import tqdm
import numpy as np
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns



In [None]:
# Load the pickled test split produced by the training run.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
file_path = './test_data.pkl'
with open(file_path, 'rb') as f:
    loaded_test_data = pickle.load(f)

# Pull out the three entries used below (KeyError if any is missing).
test_loader, label_to_index, attack_types_list = (
    loaded_test_data[key]
    for key in ("test_loader", "label_to_index", "attack_types_list")
)

In [None]:
class LSTMClassifier(nn.Module):
    """LSTM encoder followed by a classification head (inconsistent definition).

    NOTE(review): ``forward`` reads ``self.bn1``, ``self.fc3``, ``self.bn2`` and
    ``self.fc4``, but ``__init__`` below never creates them. An instance built
    with this ``__init__`` will therefore raise ``AttributeError`` when called.
    Presumably this class was trimmed after training. An instance restored via
    pickle keeps whatever submodules it was saved with, so it may still work.
    TODO confirm which definition the saved models were trained with before
    changing either method.

    Args:
        input_dim: number of input features per time step given to ``nn.LSTM``.
        hidden_dim: LSTM hidden size; also the width of the returned feature vector.
        num_layers: number of stacked LSTM layers.
        num_classes: output width of ``fc2``.
    """
    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super(LSTMClassifier, self).__init__()
        # batch_first=True: inputs are laid out as (batch, seq, input_dim).
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        
        # Maps the last-step LSTM output (hidden_dim) to num_classes.
        self.fc2 = nn.Linear(hidden_dim, num_classes)
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        """Return ``(scores, feature_vector)`` for a batch of sequences.

        NOTE(review): depends on bn1/fc3/bn2/fc4, which this class's __init__
        does not define (see class docstring).
        """
        out, _ = self.lstm(x)
        # LSTM output at the last time step (batch_first layout).
        feature_vector = out[:,-1,:]
        out = self.fc2(feature_vector)
        # NOTE(review): fc2 outputs num_classes features, so bn1/fc3 would see
        # num_classes-wide input — unverified.
        out = self.bn1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc3(out) # original note "128" — presumably fc3's output width; unverified
        out = self.bn2(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc4(out)

        return out, feature_vector

    def get_features(self, x):
        """Return the LSTM output at the last time step, without the classification head."""
        out, _ = self.lstm(x)
        return out[:,-1,:]

# LSTM hyperparameters; presumably these must match the values the saved
# models were trained with.
input_dim, hidden_dim, num_layers = 10, 256, 4
learning_rate = 0.001  # not referenced by any later cell in this notebook

In [None]:
class FeatureExtraction_LSTMClassifier(nn.Module):
    """LSTM encoder with a single linear classification head.

    ``forward`` returns both the head's scores and the last-step LSTM output,
    so the same pass can feed downstream per-group models.

    Args:
        input_dim: number of input features per time step.
        hidden_dim: LSTM hidden size (width of the feature vector).
        num_layers: number of stacked LSTM layers.
        num_classes: output width of the linear head ``fc2``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers, num_classes):
        super().__init__()
        # Inputs are laid out as (batch, seq, input_dim) because batch_first=True.
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        self.fc2 = nn.Linear(hidden_dim, num_classes)

        # Registered but not used by forward() or get_features().
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)  # Dropout with 50% drop probability

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

    def forward(self, x):
        """Return ``(scores, feature_vector)`` where scores = fc2(feature_vector)."""
        last_step = self.get_features(x)
        return self.fc2(last_step), last_step

    def get_features(self, x):
        """Return the LSTM output at the final time step for every sequence."""
        sequence_out = self.lstm(x)[0]
        return sequence_out[:, -1, :]


In [None]:
class MLP(nn.Module):
    """Three-layer ReLU MLP used as a per-group classification head.

    Layer widths: ``input_size -> hidden_size -> hidden_size // 2 -> num_classes``.
    No activation is applied to the output layer.

    Args:
        num_classes: number of output scores.
        input_size: width of the input feature vector. Defaults to 256, the
            previously hard-coded value (it matches ``hidden_dim = 256`` set
            earlier in this notebook).
        hidden_size: width of the first hidden layer; the second hidden layer
            is ``hidden_size // 2``. Defaults to 256.
    """

    def __init__(self, num_classes, input_size=256, hidden_size=256):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.fc1 = nn.Linear(self.input_size, self.hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(self.hidden_size, self.hidden_size // 2)
        self.fc3 = nn.Linear(self.hidden_size // 2, num_classes)

    def forward(self, x):
        """Map a ``(batch, input_size)`` tensor to ``(batch, num_classes)`` scores."""
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.relu(out)
        out = self.fc3(out)
        return out

In [None]:
# Use the first CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): the class count 20 is hard-coded rather than derived from
# attack_types_list, and lstm_model is not used by any later cell.
lstm_model = LSTMClassifier(
    input_dim, hidden_dim, num_layers, 20
).to(device)
hiera_model_save_path = './models/'
num_classes = len(attack_types_list)
# Placeholder; reassigned from the pickle loaded further below.
hierarchical_models = defaultdict(list)

In [None]:
import pickle

def load_saved_data(file_path):
    """Unpickle and return the object stored at ``file_path``.

    NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

# NOTE(review): the path spells "hierachical"; it must match the file on disk.
file_path = './models/hierachical_models.pkl'
loaded_data = load_saved_data(file_path)

# Replaces the empty placeholder dict with the trained per-group models.
hierarchical_models = loaded_data["hierarchical_models"]

In [None]:
import torch
from collections import defaultdict

class HierarchicalClassifier:
    """Two-stage classifier: route each sample to a group, then pick a label within it.

    Stage 1 runs the ``'combined'`` model, which returns
    ``(group_scores, feature_vectors)``. The arg-max of ``group_scores`` selects
    a group name. Stage 2 handles each group in one of two ways:

    * one-class groups emit their single label directly;
    * other groups pass the stage-1 feature vectors to the group model and take
      the arg-max of its scores.

    Args:
        device: torch device that input batches are moved to.
        hierarchical_models: mapping ``group_name -> list``. Only the first entry
            of each list is used, and it must provide the keys ``'model'``,
            ``'label_to_index'`` and ``'is_one_class'``. A ``'combined'`` entry
            is required by :meth:`classify`.
        attack_label_to_index: mapping from label string to the global integer
            index returned by :meth:`classify`.
    """

    def __init__(self, device, hierarchical_models, attack_label_to_index):
        self.device = device
        self.hierarchical_models = hierarchical_models
        self.attack_label_to_index = attack_label_to_index
        self.models_cache = {
            group_name: self._cache_entry(entries[0])
            for group_name, entries in self.hierarchical_models.items()
        }

    @staticmethod
    def _cache_entry(model_info):
        """Repackage one saved model entry and precompute its index->label map."""
        label_map = model_info['label_to_index']
        return {
            'model': model_info['model'],
            'lbl2idx': label_map,
            'idx2lbl': {idx: label for label, idx in label_map.items()},
            'is_one_class': model_info['is_one_class'],
        }

    def classify(self, batch_seq):
        """Return a list holding one global label index per sample in ``batch_seq``."""
        batch_seq = batch_seq.to(self.device)
        router = self.models_cache['combined']
        router_model = router['model']
        router_model.eval()

        with torch.no_grad():
            group_scores, features = router_model(batch_seq)
        # Convert once to Python ints instead of calling .item() per element.
        routed_groups = torch.max(group_scores, dim=1)[1].tolist()

        predictions = [None] * batch_seq.size(0)

        # Bucket sample positions by routed group, in first-seen group order.
        members_by_group = defaultdict(list)
        for pos, group_idx in enumerate(routed_groups):
            members_by_group[group_idx].append(pos)

        for group_idx, members in members_by_group.items():
            entry = self.models_cache[router['idx2lbl'][group_idx]]

            if entry['is_one_class']:
                sole_label = list(entry['lbl2idx'].keys())[0]
                global_idx = self.attack_label_to_index[sole_label]
                for pos in members:
                    predictions[pos] = global_idx
                continue

            sub_model = entry['model']
            sub_model.eval()
            with torch.no_grad():
                sub_scores = sub_model(features[members])
            local_indices = torch.max(sub_scores, dim=1)[1].tolist()

            for pos, local_idx in zip(members, local_indices):
                label = entry['idx2lbl'][local_idx]
                predictions[pos] = self.attack_label_to_index[label]

        return predictions
classifier = HierarchicalClassifier(
    device=device,
    hierarchical_models=hierarchical_models,
    attack_label_to_index=label_to_index,
)

In [None]:

def atk_find_key(dictionary, values):
    """Reverse-look-up keys of ``dictionary`` for each entry of ``values``.

    For every value (in order), all keys whose value compares equal to it are
    emitted, in the dictionary's iteration order. A value with no match
    contributes nothing, and a value shared by several keys contributes all of
    them.
    """
    return [
        key
        for wanted in values
        for key, mapped in dictionary.items()
        if mapped == wanted
    ]

def get_metrics_from_cm(cm, class_names):
    """Build a per-class precision / recall / F1 table from a confusion matrix.

    Args:
        cm: square confusion matrix where ``cm[i, j]`` counts samples of true
            class ``i`` predicted as class ``j`` (scikit-learn's layout).
        class_names: one display name per row/column of ``cm``, in the same order.

    Returns:
        pandas.DataFrame with columns ``Label``, ``Precision``, ``Recall`` and
        ``F1 Score``, one row per class.

    Raises:
        ValueError: if ``cm`` is not square or ``len(class_names)`` differs from
            its size.

    Undefined ratios are reported as 0.0 instead of NaN, and no divide-by-zero
    RuntimeWarning is emitted. This covers classes that are never predicted,
    classes absent from the true labels, and F1 when precision and recall are
    both 0. It mirrors scikit-learn's ``zero_division=0`` convention.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"confusion matrix must be square, got shape {cm.shape}")
    if len(class_names) != cm.shape[0]:
        raise ValueError(
            f"{len(class_names)} class names given for a {cm.shape[0]}x{cm.shape[1]} confusion matrix"
        )

    def _safe_divide(numerator, denominator):
        # Elementwise numerator / denominator, with 0.0 wherever denominator == 0.
        return np.divide(
            numerator,
            denominator,
            out=np.zeros(numerator.shape, dtype=float),
            where=denominator != 0,
        )

    true_positives = np.diag(cm).astype(float)
    true_labels = cm.sum(axis=1)       # row sums: samples per true class
    predicted_labels = cm.sum(axis=0)  # column sums: samples per predicted class

    precision = _safe_divide(true_positives, predicted_labels)
    recall = _safe_divide(true_positives, true_labels)
    f1_scores = _safe_divide(2 * precision * recall, precision + recall)

    df = pd.DataFrame({
        'Label': class_names,
        'Precision': precision,
        'Recall': recall,
        'F1 Score': f1_scores
    })

    return df


def evaluate_model_batchwise(data_loader, classifier, device, labels=None):
    """Run ``classifier`` over every batch of ``data_loader`` and score the predictions.

    Args:
        data_loader: iterable yielding ``(batch_seq, batch_labels)`` tensor pairs.
        classifier: object exposing ``classify(batch_seq)``, which returns one
            predicted class index per sample (a list or a tensor).
        device: torch device each ``batch_seq`` is moved to before classification.
        labels: optional list of class indices forwarded to ``confusion_matrix``.
            It fixes the matrix's rows/columns to that list even when some
            classes never occur in the true or predicted labels. ``None`` (the
            default) keeps scikit-learn's behaviour of using only the labels
            that appear.

    Returns:
        ``(accuracy, macro_f1, macro_precision, macro_recall, confusion_matrix)``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    all_predictions = []
    all_true_labels = []

    # Defensive: classify() may already disable autograd, but this makes sure
    # no graph is built for the whole evaluation pass.
    with torch.no_grad():
        for batch_seq, batch_labels in tqdm(data_loader, desc="Evaluating multi-class"):
            batch_seq = batch_seq.to(device)
            preds = classifier.classify(batch_seq)
            if isinstance(preds, torch.Tensor):
                preds = preds.cpu().numpy()
            all_predictions.extend(preds)
            all_true_labels.extend(batch_labels.cpu().numpy())

    if not all_true_labels:
        raise ValueError("data_loader yielded no samples; nothing to evaluate")

    acc = accuracy_score(all_true_labels, all_predictions)
    f1 = f1_score(all_true_labels, all_predictions, average='macro')
    precision_macro = precision_score(all_true_labels, all_predictions, average='macro')
    recall_macro = recall_score(all_true_labels, all_predictions, average='macro')
    cm_multi = confusion_matrix(all_true_labels, all_predictions, labels=labels)

    return acc, f1, precision_macro, recall_macro, cm_multi

In [None]:
# Score the hierarchical classifier on the held-out test loader.
acc, f1, precision_macro, recall_macro, cm = evaluate_model_batchwise(
    data_loader=test_loader,
    classifier=classifier,
    device=device,
)
print(f"Accuracy: {acc:.4f}, F1: {f1:.4f}")
print(f"precision: {precision_macro:.4f}, recall: {recall_macro:.4f}")

In [None]:
# Per-class precision / recall / F1 derived from the test confusion matrix.
# NOTE(review): assumes attack_types_list is ordered by class index — confirm.
df = get_metrics_from_cm(cm, class_names=attack_types_list)
print(df)