Загрузим тестовые данные

In [127]:
import os
from tqdm import tqdm
import torch

from src.models.pl_datamodule import MyDataModule
from src.models.bert_models import TransformersTextClassifier

In [23]:
import warnings
from transformers import logging


warnings.filterwarnings('ignore')
logging.set_verbosity_error()

Для оценки выберем следующие модели-трансформеры из библиотеки huggingface.
* [distilbert-base-uncased](https://huggingface.co/distilbert-base-uncased)
* ["smallbenchnlp/bert-small"]("smallbenchnlp/bert-small")
* [prajjwal1/bert-mini](https://huggingface.co/prajjwal1/bert-mini)
* [prajjwal1/bert-tiny](https://huggingface.co/prajjwal1/bert-tiny)

Вызовем тестовые данные для проверки, которые наша модель не видела при обучении

In [3]:
os.chdir('..')

Имена и расположение наших переменных

In [172]:
data_path = 'data/processed/processed_df.csv'
model_names  = ['distilbert-base-uncased',  "smallbenchnlp/bert-small",
         'prajjwal1/bert-mini', 'prajjwal1/bert-tiny']

model_paths = ["models\distil_bert\distilbert-base-uncased.ckpt", 
              r"models\smallbenchnlp\bert-small_1e-05_16\smallbenchnlp\bert-small.ckpt",
               r'models\bert-mini_1e-05_32\prajjwal1\bert-mini.ckpt',
                r'models\bert-tiny_1e-05_32\prajjwal1\bert-tiny.ckpt']
batch_sizes = [8, 16, 32, 64]

In [106]:
len(dm.test_dataloader())

1205

In [160]:
def test_prediction_and_labels(model_name,  model_path, batch_size):
    dm = MyDataModule(data_path, model_name, batch_size)
    dm.prepare_data()
    dm.setup('test')
    trained_model = TransformersTextClassifier(model_name)
    checkpoint = torch.load(model_path)
    trained_model.load_state_dict(checkpoint['state_dict'])
    trained_model.eval()
    trained_model.freeze()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trained_model = trained_model.to(device)
    flag = False

    labels = []
    for item in tqdm(dm.test_dataloader()):
        _, prediction = trained_model(
        item["input_ids"].to(device),
        item["attention_mask"].to(device)
      )
        prediction = torch.sigmoid(prediction)
        if not flag:
            predictions = prediction
            labels = item["labels"].int()
            flag = True
        else:
            predictions = torch.vstack((predictions, prediction))
            labels = torch.vstack((labels, item["labels"].int()))

    predictions = predictions.detach().cpu()
    labels = labels.detach().cpu()
    return predictions, labels

In [148]:
from torchmetrics import Accuracy
from torchmetrics.classification import MultilabelAUROC
from sklearn.metrics import classification_report
import numpy as np

In [140]:
def accuracy_score(predictions, labels):
    accuracy = Accuracy(task="multilabel", num_labels=20)
    score = accuracy(predictions, labels)
    return score

In [141]:
def multilabel_auroc(predictions, labels):
    ml_auroc = MultilabelAUROC(num_labels=20, average="macro", thresholds=None)
    score = ml_auroc(predictions, labels)
    return score

In [142]:
def multilabel_auroc(predictions, labels):
    ml_auroc = MultilabelAUROC(num_labels=20, average="macro", thresholds=None)
    score = ml_auroc(predictions, labels)
    return score

In [149]:
def mulilabel_classification_report(predictions, labels):
    y_pred = predictions.numpy()
    y_true = labels.numpy()
    upper, lower = 1, 0
    y_pred = np.where(y_pred > 0.5, upper, lower)
    print(classification_report(
      y_true,
      y_pred,
      target_names=label_namre,
      zero_division=0
    ))

## 1. Distilbert_base_uncased

In [138]:
predictions, labels = test_prediction_and_labels(model_names[0], model_paths[0], batch_size[0])

(64388, 21) processed_df.shape


100%|██████████████████████████████████████████████████████████████████████████████| 1205/1205 [04:09<00:00,  4.83it/s]


In [143]:
accuracy_score(predictions, labels)

tensor(0.9717)

In [144]:
multilabel_auroc(predictions, labels)

tensor(0.9629)

In [151]:
mulilabel_classification_report(predictions, labels)

                                        precision    recall  f1-score   support

                   encodded_label_CASB       0.88      0.80      0.84       522
                    encodded_label_EDR       0.93      0.69      0.80       667
                    encodded_label_MDR       0.96      0.80      0.87       522
                    encodded_label_NDR       0.96      0.85      0.90       209
                   encodded_label_NGFW       0.78      0.80      0.79       243
                   encodded_label_SASE       0.89      0.81      0.85       522
                   encodded_label_SIEM       0.88      0.64      0.74       522
                   encodded_label_SOAR       0.96      0.75      0.84       367
       encodded_label_anti-counterfeit       0.82      0.86      0.84       522
    encodded_label_application_control       0.80      0.89      0.84       522
           encodded_label_atm_security       0.85      0.84      0.85       522
               encodded_label_honeypot 

## 2. Smallbenchnlp/bert-small

In [161]:
predictions, labels = test_prediction_and_labels(model_names[1], model_paths[1], batch_size[1])

(64388, 21) processed_df.shape


100%|████████████████████████████████████████████████████████████████████████████████| 603/603 [01:07<00:00,  8.91it/s]


In [162]:
accuracy_score(predictions, labels)

tensor(0.9434)

In [163]:
multilabel_auroc(predictions, labels)

tensor(0.8396)

In [164]:
mulilabel_classification_report(predictions, labels)

                                        precision    recall  f1-score   support

                   encodded_label_CASB       0.00      0.00      0.00       522
                    encodded_label_EDR       0.00      0.00      0.00       667
                    encodded_label_MDR       0.00      0.00      0.00       522
                    encodded_label_NDR       0.00      0.00      0.00       209
                   encodded_label_NGFW       0.00      0.00      0.00       243
                   encodded_label_SASE       1.00      0.06      0.12       522
                   encodded_label_SIEM       0.00      0.00      0.00       522
                   encodded_label_SOAR       0.00      0.00      0.00       367
       encodded_label_anti-counterfeit       1.00      0.26      0.42       522
    encodded_label_application_control       0.96      0.50      0.66       522
           encodded_label_atm_security       0.00      0.00      0.00       522
               encodded_label_honeypot 

## 3. prajjwal1/bert-mini

In [167]:
predictions, labels = test_prediction_and_labels(model_names[2], model_paths[2], batch_size[2])

(64388, 21) processed_df.shape


100%|████████████████████████████████████████████████████████████████████████████████| 302/302 [00:28<00:00, 10.70it/s]


In [168]:
accuracy_score(predictions, labels)

tensor(0.9555)

In [169]:
multilabel_auroc(predictions, labels)

tensor(0.9160)

In [170]:
mulilabel_classification_report(predictions, labels)

                                        precision    recall  f1-score   support

                   encodded_label_CASB       0.95      0.67      0.78       522
                    encodded_label_EDR       0.82      0.35      0.49       667
                    encodded_label_MDR       0.95      0.68      0.79       522
                    encodded_label_NDR       0.97      0.55      0.70       209
                   encodded_label_NGFW       0.88      0.36      0.51       243
                   encodded_label_SASE       0.85      0.70      0.77       522
                   encodded_label_SIEM       0.93      0.35      0.51       522
                   encodded_label_SOAR       0.93      0.58      0.72       367
       encodded_label_anti-counterfeit       0.91      0.72      0.80       522
    encodded_label_application_control       0.85      0.67      0.75       522
           encodded_label_atm_security       0.81      0.69      0.74       522
               encodded_label_honeypot 

## 4. prajjwal1/bert-mini

In [173]:
predictions, labels = test_prediction_and_labels(model_names[3], model_paths[3], batch_size[3])

(64388, 21) processed_df.shape


100%|████████████████████████████████████████████████████████████████████████████████| 151/151 [00:27<00:00,  5.41it/s]


In [174]:
accuracy_score(predictions, labels)

tensor(0.9383)

In [175]:
multilabel_auroc(predictions, labels)

tensor(0.8529)

In [176]:
mulilabel_classification_report(predictions, labels)

                                        precision    recall  f1-score   support

                   encodded_label_CASB       0.00      0.00      0.00       522
                    encodded_label_EDR       0.00      0.00      0.00       667
                    encodded_label_MDR       0.00      0.00      0.00       522
                    encodded_label_NDR       0.00      0.00      0.00       209
                   encodded_label_NGFW       0.00      0.00      0.00       243
                   encodded_label_SASE       1.00      0.00      0.00       522
                   encodded_label_SIEM       0.00      0.00      0.00       522
                   encodded_label_SOAR       0.00      0.00      0.00       367
       encodded_label_anti-counterfeit       1.00      0.28      0.44       522
    encodded_label_application_control       1.00      0.52      0.68       522
           encodded_label_atm_security       1.00      0.01      0.02       522
               encodded_label_honeypot 

### На данном этапе исследования сложно выбрать модель для запуска в продакшн.

Модель Distilbert_base_uncased показала наибольшую точность, удовлетворительные результаты на несбалансированных выборках, кроме того она обучалась на 5 эпохах, и на 5 эпохе наблюдалась тенденция к снижению значения функции потерь на валидационной и тренировочной выборке, так что при дальнейшем обучении, аугментации данных возможно еще улучшить ее результаты. Однако она максимальна затратна по производительности, возможно улучшение производительности инференса модели с помощью квантизации и др техник.


Модель 3 так же показала неплохие результаты, и имеет смысл взять ее для дальнейших исследований.

### Что еще можно сделать
1. Не протестированы простейшие модели типа TF-IDF
2. Не протестированы RNN модели
3. Провести аугментацию данных на редких классах