<a href="https://colab.research.google.com/github/jhen-fang/P_Project-24/blob/main/P_Project_24_final_fn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **1. 專案名稱：使用網路威脅資料集實作文本分類**

- **姓名：蔡甄芳**
- **系級：資管四乙**
- **學號：109306056**
- **GitHUB codes: https://github.com/jhen-fang/P_Project-24/blob/main/P_Project_24_final.ipynb**
- **Colab link: https://colab.research.google.com/drive/13K60edQIt0uHyqfikX1O7napKgZV93lS?usp=sharing**

## **2. 資料集介紹**

##### (1) **資料集名稱：Cyber Threat Dataset: Network, Text & Relation**

##### (2) **資料集來源：https://www.kaggle.com/datasets/ramoliyafenil/text-based-cyber-threat-detection**

##### (3) **資料集簡介：這個資料集包含了網路流量資料( network traffic data), 文字內容( textual content), 實體關係( entity realationships)等等, 可用來檢測、診斷和減輕網路威脅。**

##### **(4) 資料集欄位：**
    - id: 資料集中每個 instance 的 identifier。
    - text: 透過網路傳輸的文字內容，如：電子郵件、訊息或網路流量負載。並包含潛在的網路威脅描述。
    - Entries: JSON 清單，包含以下
        - sender_id
        - label : 識別出的網路威脅或攻擊模式
        - start_offset
        - end_offset
        - receive_ids
    - relations: 一個 tuples 表示實體關係，包含一對實體 IDs ( source and target )
    - diagnosis: 對已經識別出的網路威脅的描述及診斷，提供見解。
    - solutions: 針對網路威脅提供解決方案或緩解策略的描述。

---


*F. Ramoliya, R. Kakkar, R. Gupta, S. Tanwar and S. Agrawal, "SEAM: Deep Learning-based Secure Message Exchange Framework For Autonomous EVs," 2023 IEEE Globecom Workshops (GC Wkshps), Kuala Lumpur, Malaysia, 2023, pp. 80-85, doi: 10.1109/GCWkshps58843.2023.10465168.*

## **3. 專案目的：**

#### **網路威脅偵測：根據 Text ( textual content 以及 network traffic data ) 分類網路威脅 ( Entries: label )**

    - Pipeline:
        - text Classification
        - zero-shot-classification
    - 目標：利用網路威脅描述(text) 分類出攻擊模式或潛在威脅(label)



## **4. 專案架構：**

#### **(1) 文本前處理**
#### **(2) 預訓練模型選擇與比較**
#### **(3) 模型 fine-tuning**
#### **(4) 性能比較**
#### **(5) Downstream task**


## **5. 程式碼實作 Downstream Task**

In [1]:
!pip install transformers pandas numpy matplotlib seaborn kaggle datasets evaluate transformers[torch]


Collecting datasets
  Downloading datasets-2.19.2-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.1/542.1 kB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting evaluate
  Downloading evaluate-0.4.2-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting requests (from transformers)
  Downloading requests-2.32.3-py3-none-any.whl (64 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.9/64.9 kB[0m [31m6.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━

In [2]:
!pip install sentence_transformers

Collecting sentence_transformers
  Downloading sentence_transformers-3.0.1-py3-none-any.whl (227 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/227.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.1/227.1 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sentence_transformers
Successfully installed sentence_transformers-3.0.1


In [3]:
import pandas as pd
import numpy as np
import torch
import evaluate
import os
from google.colab import userdata
from datasets import Dataset, DatasetDict, load_metric
from torch import tensor
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    pipeline,
    AutoModelForQuestionAnswering,
    DefaultDataCollator,
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    GPT2Tokenizer,
    GPT2LMHeadModel,
    AutoModelForCausalLM,
    AutoTokenizer
    )
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
import re
from sklearn.feature_extraction.text import CountVectorizer
from gensim.models import Word2Vec, FastText
from sentence_transformers import SentenceTransformer
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from transformers import get_linear_schedule_with_warmup
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import RobertaTokenizer, RobertaForSequenceClassification
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification
from datasets import load_dataset
from gensim.models import FastText
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import TrainerCallback
from sklearn.metrics import confusion_matrix


In [4]:
torch.multiprocessing.set_start_method('spawn', force=True)


In [5]:
# 下載資料集
api_key = userdata.get('kaggle_key')
username = userdata.get('kaggle_username')

os.environ['KAGGLE_USERNAME'] = username
os.environ['KAGGLE_KEY'] = api_key

!kaggle datasets download -d ramoliyafenil/text-based-cyber-threat-detection

# 解壓縮資料集
!unzip text-based-cyber-threat-detection.zip

Dataset URL: https://www.kaggle.com/datasets/ramoliyafenil/text-based-cyber-threat-detection
License(s): Apache 2.0
Downloading text-based-cyber-threat-detection.zip to /content
 77% 3.00M/3.91M [00:01<00:00, 3.43MB/s]
100% 3.91M/3.91M [00:01<00:00, 3.31MB/s]
Archive:  text-based-cyber-threat-detection.zip
  inflating: Cyber-Threat-Intelligence-Custom-Data_new_processed.csv  
  inflating: all.jsonl               
  inflating: cyber-threat-intelligence-splited_test.csv  
  inflating: cyber-threat-intelligence-splited_train.csv  
  inflating: cyber-threat-intelligence-splited_validate.csv  
  inflating: cyber-threat-intelligence_all.csv  
  inflating: test.jsonl              
  inflating: train.jsonl             
  inflating: validation.jsonl        


In [6]:
custom_data_new_processed = pd.read_csv('/content/Cyber-Threat-Intelligence-Custom-Data_new_processed.csv')
custom_data_new_processed.head()

Unnamed: 0,id,text,relations,diagnosis,solutions,id_1,label_1,start_offset_1,end_offset_1,id_2,label_2,start_offset_2,end_offset_2,id_3,label_3,start_offset_3,end_offset_3
0,249,A cybersquatting domain save-russia[.]today is...,"[{'from_id': 44658, 'id': 9, 'to_id': 44659, '...",The diagnosis is a cyber attack that involves ...,1. Implementing DNS filtering to block access ...,44656,attack-pattern,2,16,44657,url,24,43,44658.0,attack-pattern,57.0,68.0
1,14309,"Like the Android Maikspy, it first sends a not...","[{'from_id': 48531, 'id': 445, 'to_id': 48532,...",The diagnosis is that the entity identified as...,1. Implementing a robust anti-malware software...,48530,SOFTWARE,9,17,48531,malware,17,24,48532.0,Infrastucture,63.0,73.0
2,13996,While analyzing the technical details of this ...,"[{'from_id': 48781, 'id': 461, 'to_id': 48782,...",Diagnosis: APT37/Reaper/Group 123 is responsib...,1. Implementing advanced threat detection tech...,48781,threat-actor,188,194,48782,threat-actor,210,217,48783.0,threat-actor,220.0,229.0
3,13600,(Note that Flash has been declared end-of-life...,"[{'from_id': 51688, 'id': 1133, 'to_id': 51689...",The diagnosis is a malware infection. The enti...,1. Implementing a robust antivirus software th...,51687,TIME,62,79,51688,malware,207,215,51689.0,malware,247.0,258.0
4,14364,Figure 21. Connection of Maikspy variants to 1...,"[{'from_id': 51780, 'id': 1161, 'to_id': 44372...",The diagnosis is that Maikspy malware variants...,1. Implementing a robust firewall system that ...,51779,URL,163,191,51777,URL,70,93,51781.0,malware,120.0,127.0


In [7]:
# 挑出 label_1, text 作為文本分類的欄位
selected_data_for_cls = custom_data_new_processed[['label_1', 'text']].rename(columns={'label_1': 'label'})
selected_data_for_cls

Unnamed: 0,label,text
0,attack-pattern,A cybersquatting domain save-russia[.]today is...
1,SOFTWARE,"Like the Android Maikspy, it first sends a not..."
2,threat-actor,While analyzing the technical details of this ...
3,TIME,(Note that Flash has been declared end-of-life...
4,URL,Figure 21. Connection of Maikspy variants to 1...
...,...,...
471,malware,"Cyclops Blink, an advanced modular botnet that..."
472,location,Sofacy Group has been associated with many at...
473,Infrastucture,The plugin has been designed to drop multiple ...
474,threat-actor,We have uncovered a cyberespionage campaign be...


In [8]:
label_counts = selected_data_for_cls['label'].value_counts()
print(label_counts)

# 因為量非常少的種類很多，因此篩選出數量大於20的label
labels_over_20 = label_counts[label_counts > 20].index

# 使用這些 label 來過濾 data_for_cls
data_for_cls = selected_data_for_cls[selected_data_for_cls['label'].isin(labels_over_20)].reset_index(drop=True)
data_for_cls.to_csv("./project_data.csv")

label
malware           141
threat-actor       68
attack-pattern     44
identity           37
vulnerability      35
SOFTWARE           25
location           23
campaign           20
tools              19
TIME               17
FILEPATH           15
hash               14
Infrastucture       7
url                 6
URL                 3
REGISTRYKEY         1
IPV4                1
Name: count, dtype: int64


### **5-1. 文本前處理**

In [9]:
# 將 text 欄位轉為小寫
def clean_text(text):
    text = text.lower()
    return text

data_for_cls['clean_text'] = data_for_cls['text'].apply(clean_text)
data_for_cls

Unnamed: 0,label,text,clean_text
0,attack-pattern,A cybersquatting domain save-russia[.]today is...,a cybersquatting domain save-russia[.]today is...
1,SOFTWARE,"Like the Android Maikspy, it first sends a not...","like the android maikspy, it first sends a not..."
2,threat-actor,While analyzing the technical details of this ...,while analyzing the technical details of this ...
3,location,The source code of this framework is shared ac...,the source code of this framework is shared ac...
4,vulnerability,The CVE-2022-22965 vulnerability allows an att...,the cve-2022-22965 vulnerability allows an att...
...,...,...,...
368,malware,BIOPASS RAT Loader Backdoor.Win64.BIOPASS.A ...,biopass rat loader backdoor.win64.biopass.a ...
369,malware,"Cyclops Blink, an advanced modular botnet that...","cyclops blink, an advanced modular botnet that..."
370,location,Sofacy Group has been associated with many at...,sofacy group has been associated with many at...
371,threat-actor,We have uncovered a cyberespionage campaign be...,we have uncovered a cyberespionage campaign be...


### **5-2. 文本特徵提取**

**以下比較四種文本特徵提取方法**
- Bag of Words
- Word2Vec
- Multilintual SBERT
- FastText

In [10]:
# 1. Bag of Words
vectorizer = CountVectorizer()
X_bow = vectorizer.fit_transform(data_for_cls['clean_text'])

# 2. Word2Vec
sentences = [row.split() for row in data_for_cls['clean_text']]
word2vec_cbow = Word2Vec(sentences, vector_size=100, window=5, min_count=1, workers=4, sg=0)
word2vec_cbow.train(sentences, total_examples=len(sentences), epochs=10)

# 每個文本的平均向量
def vectorize_text(text, model):
    vectors = [model.wv[word] for word in text if word in model.wv]
    return np.mean(vectors, axis=0) if vectors else np.zeros(model.vector_size)

X_w2v = np.array([vectorize_text(text, word2vec_cbow) for text in sentences])

# 3. Multilingual SBERT
model_sbert = SentenceTransformer('all-MiniLM-L12-v2')
X_sbert = model_sbert.encode(data_for_cls['clean_text'])


sentences = [row.split() for row in data_for_cls['clean_text']]

# 4. FastText
fasttext_model = FastText(vector_size=100, window=5, min_count=1, workers=4)
# 構建詞彙表
fasttext_model.build_vocab(corpus_iterable=sentences)
fasttext_model.train(corpus_iterable=sentences, total_examples=len(sentences), epochs=10)
X_ft = np.array([vectorize_text(text, fasttext_model) for text in sentences])


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/615 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/133M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/352 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

**進行初步的模型訓練與評估**

In [11]:
def train_evaluate(X, y):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = LogisticRegression(random_state=42, max_iter=1000)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)
    print(classification_report(y_test, y_pred))

# 訓練並評估每種向量化方法
print("Bag of Words Results:")
train_evaluate(X_bow, data_for_cls['label'])

print("Word2Vec Results:")
train_evaluate(X_w2v, data_for_cls['label'])

print("SBERT Results:")
train_evaluate(X_sbert, data_for_cls['label'])

print("FastText Results:")
train_evaluate(X_ft, data_for_cls['label'])

Bag of Words Results:
                precision    recall  f1-score   support

      SOFTWARE       1.00      0.20      0.33         5
attack-pattern       0.83      0.50      0.62        10
      identity       0.75      0.43      0.55         7
      location       0.00      0.00      0.00         3
       malware       0.67      0.76      0.71        37
  threat-actor       0.23      0.50      0.32         6
 vulnerability       0.71      0.71      0.71         7

      accuracy                           0.60        75
     macro avg       0.60      0.44      0.46        75
  weighted avg       0.66      0.60      0.60        75

Word2Vec Results:
                precision    recall  f1-score   support

      SOFTWARE       0.00      0.00      0.00         5
attack-pattern       0.00      0.00      0.00        10
      identity       0.00      0.00      0.00         7
      location       0.00      0.00      0.00         3
       malware       0.49      0.95      0.65        37
  th

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


**特徵提取結果評估：**

1. **Bag of Words (BoW)
準確率：0.60**

    BoW在所有方法中表現最均衡，尤其是對於malware和vulnerability類別表現良好。這可能是因為BoW方法能夠捕捉到這些類別中特定的關鍵詞。然而，對於location和SOFTWARE類別，此方法未能有效識別。
2. **Word2Vec
準確率：0.47**

    Word2Vec在幾乎所有類別中的表現均較差，只有在malware類別上達到較高的 recall。這顯示Word2Vec模型可能沒有很好地捕捉到用於分類的語義信息，或許是模型訓練不足或數據不適合。
3. **SBERT (Sentence-BERT)
準確率：0.60**

    SBERT的整體表現與BoW相當，特別是在malware和vulnerability類別上，顯示了較好的結果。這表明SBERT在捕捉句子級別的語義信息上相對有效。

4. **FastText
準確率：0.49**
    FastText與Word2Vec的表現類似，主要在malware類別上有所表現，其他類別幾乎無法識別。這可能是因為模型訓練不足，或者該方法對於當前的數據集特性不是特別合適。
**總結**
- 最佳表現：BoW和SBERT在多個指標上表現較好，特別是在malware和vulnerability類別上。
- 一般表現：Word2Vec和FastText在大多數類別上表現不佳，這可能需要進一步調整模型參數或進行更深入的特徵工程。

**下一步: 使用 SBERT 作為接下來的特徵提取工具，並使用其他機器學習分類器評估效能**


In [12]:
class SbertDataset(Dataset):
    def __init__(self, descriptions, labels, model):
        self.descriptions = descriptions
        self.labels = labels
        self.model = model

    def __len__(self):
        return len(self.descriptions)

    def __getitem__(self, item):
        description = str(self.descriptions[item])
        label = self.labels[item]
        embedding = self.model.encode(description)

        return {
            'description_text': description,
            'embeddings': torch.tensor(embedding, dtype=torch.float),
            'labels': torch.tensor(label, dtype=torch.long)
        }

In [13]:
label_encoder = LabelEncoder()

data_for_sbert = data_for_cls.copy()
data_for_sbert['label'] = label_encoder.fit_transform(data_for_sbert['label'])

# 分割數據集
train_df, val_df = train_test_split(data_for_sbert, test_size=0.1, random_state=42)

# 初始化SBERT模型
sbert_model = SentenceTransformer('all-MiniLM-L12-v2')

# 創建數據集
train_dataset = SbertDataset(
    descriptions=train_df['clean_text'].to_numpy(),
    labels=train_df['label'].to_numpy(),
    model=sbert_model
)
val_dataset = SbertDataset(
    descriptions=val_df['clean_text'].to_numpy(),
    labels=val_df['label'].to_numpy(),
    model=sbert_model
)

# 創建DataLoader
train_loader = DataLoader(train_dataset, batch_size=16, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=16, num_workers=0)



In [14]:
def extract_features_and_labels(dataloader):
    features = []
    labels = []
    with torch.no_grad():
        for batch in dataloader:
            features.append(batch['embeddings'].numpy())
            labels.append(batch['labels'].numpy())
    # 將列表轉換為NumPy數組
    features = np.vstack(features)
    labels = np.concatenate(labels)
    return features, labels

# 提取訓練和驗證數據的特徵和標籤
X_train, y_train = extract_features_and_labels(train_loader)
X_val, y_val = extract_features_and_labels(val_loader)

In [15]:
# 初始化分類器
classifiers = {
    'Logistic Regression': LogisticRegression(max_iter=1000),
    'SVM': SVC(),
    'Random Forest': RandomForestClassifier()
}

# 訓練每個分類器並評估效能
for name, clf in classifiers.items():
    clf.fit(X_train, y_train)
    y_pred = clf.predict(X_val)
    print(f"{name} Performance:")
    print(classification_report(y_val, y_pred))

Logistic Regression Performance:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         2
           1       0.00      0.00      0.00         4
           2       1.00      0.33      0.50         3
           3       0.00      0.00      0.00         1
           4       0.76      0.86      0.81        22
           5       0.00      0.00      0.00         4
           6       0.50      1.00      0.67         2

    accuracy                           0.58        38
   macro avg       0.32      0.31      0.28        38
weighted avg       0.55      0.58      0.54        38

SVM Performance:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         2
           1       1.00      0.25      0.40         4
           2       1.00      0.67      0.80         3
           3       0.00      0.00      0.00         1
           4       0.81      0.95      0.88        22
           5       0.40     

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Random Forest Performance:
              precision    recall  f1-score   support

           0       1.00      0.50      0.67         2
           1       0.00      0.00      0.00         4
           2       1.00      0.33      0.50         3
           3       0.00      0.00      0.00         1
           4       0.81      0.95      0.88        22
           5       0.29      0.50      0.36         4
           6       0.67      1.00      0.80         2

    accuracy                           0.71        38
   macro avg       0.54      0.47      0.46        38
weighted avg       0.66      0.71      0.66        38



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### **5-3. Model Fine-tuning**

In [16]:
# 選擇模型和tokenizer
model_checkpoint = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_checkpoint)
model = BertForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)
# print(data_for_cls.head())
def tokenize_function(examples):
    return tokenizer(examples['clean_text'], padding="max_length", truncation=True)

data_for_cls.to_csv('data_for_cls.csv', index=False)
dataset = load_dataset('csv', data_files={'data': 'data_for_cls.csv'})['data']

train_test_split = dataset.train_test_split(test_size=0.1)
train_val_split = train_test_split['train'].train_test_split(test_size=0.2)

dataset_dict_for_cls = DatasetDict({
    'train': train_val_split['train'],
    'validation': train_val_split['test'],
    'test': train_test_split['test']
})

test_data = dataset_dict_for_cls['test']

# 挑出 train, validation set 中的 label 標籤
unique_labels_train = set(dataset_dict_for_cls['train']['label'])
unique_labels_validation = set(dataset_dict_for_cls['validation']['label'])
unique_labels_test = set(dataset_dict_for_cls['test']['label'])
all_unique_labels = unique_labels_train.union(unique_labels_validation)
all_unique_labels = all_unique_labels.union(unique_labels_test)

# 創建 label_id : str - int mapping
label2id = {label: idx for idx, label in enumerate(all_unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}

print("Label to ID mapping:", label2id)

def label_to_id(example):
    example['label'] = label2id[example['label']]
    return example

dataset_dict_for_cls = dataset_dict_for_cls.map(label_to_id, batched=False)

print("Mapped Dataset:", dataset_dict_for_cls['test'][0])

train_dataset = dataset_dict_for_cls['train'].map(tokenize_function, batched=True).remove_columns([col for col in dataset_dict_for_cls['train'].column_names if col not in ["input_ids", "attention_mask", "label"]])
val_dataset = dataset_dict_for_cls['validation'].map(tokenize_function, batched=True).remove_columns([col for col in dataset_dict_for_cls['validation'].column_names if col not in ["input_ids", "attention_mask", "label"]])
test_dataset = dataset_dict_for_cls['test'].map(tokenize_function, batched=True)
print(train_dataset[0])

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Generating data split: 0 examples [00:00, ? examples/s]

Label to ID mapping: {'identity': 0, 'malware': 1, 'SOFTWARE': 2, 'vulnerability': 3, 'threat-actor': 4, 'location': 5, 'attack-pattern': 6}


Map:   0%|          | 0/268 [00:00<?, ? examples/s]

Map:   0%|          | 0/67 [00:00<?, ? examples/s]

Map:   0%|          | 0/38 [00:00<?, ? examples/s]

Mapped Dataset: {'label': 6, 'text': 'Carbanak also performs techniques for disabling security tools, deleting files that are left in malicious activity, and modifying registry to hide configuration information.', 'clean_text': 'carbanak also performs techniques for disabling security tools, deleting files that are left in malicious activity, and modifying registry to hide configuration information.'}


Map:   0%|          | 0/268 [00:00<?, ? examples/s]

Map:   0%|          | 0/67 [00:00<?, ? examples/s]

Map:   0%|          | 0/38 [00:00<?, ? examples/s]

{'label': 6, 'input_ids': [101, 2004, 2540, 3468, 2098, 4473, 1996, 17346, 2000, 26988, 3638, 2013, 1996, 8211, 5080, 1010, 2027, 2064, 12850, 3278, 8310, 1997, 3595, 2592, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [17]:
class ModelTrainer:
    def __init__(self, model_name, num_labels, id2label, label2id):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels, id2label=id2label, label2id=label2id, ignore_mismatched_sizes=True)

    def encode_data(self, df):
        # 對數據進行編碼
        return df.apply(lambda x: self.tokenizer(x['clean_text'], padding='max_length', max_length=512, truncation=True, return_tensors="pt"), axis=1)

    def prepare_dataset(self, df):
        # 準備數據集
        encoded_data = self.encode_data(df)
        # 轉換數據為 PyTorch Dataset 格式
        dataset = [{'input_ids': data['input_ids'].squeeze(), 'attention_mask': data['attention_mask'].squeeze(), 'labels': torch.tensor(label)} for data, label in zip(encoded_data, df['label'])]
        return dataset

    def compute_metrics(self, eval_pred):
        accuracy_metric = load_metric("accuracy")
        precision_metric = load_metric("precision")
        recall_metric = load_metric("recall")
        f1_metric = load_metric("f1")

        logits, labels = eval_pred
        predictions = np.argmax(logits, axis=-1)

        accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
        precision = precision_metric.compute(predictions=predictions, references=labels, average="weighted")
        recall = recall_metric.compute(predictions=predictions, references=labels, average="weighted")
        f1 = f1_metric.compute(predictions=predictions, references=labels, average="weighted")

        return {
            'accuracy': accuracy['accuracy'],
            'precision': precision['precision'],
            'recall': recall['recall'],
            'f1': f1['f1']
        }

    def plot_confusion_matrix(self, eval_dataset):
        predictions = self.trainer.predict(eval_dataset).predictions
        predicted_labels = np.argmax(predictions, axis=1)
        true_labels = [example['labels'].item() for example in eval_dataset]

        cm = confusion_matrix(true_labels, predicted_labels, normalize='true')
        plt.figure(figsize=(10, 8))
        sns.heatmap(cm, annot=True, fmt=".2f", cmap='Blues')
        plt.xlabel('Predicted labels')
        plt.ylabel('True labels')
        plt.title('Confusion Matrix')
        plt.show()

    def train(self, train_dataset, val_dataset, save_path):
        # 設置訓練參數
        training_args = TrainingArguments(
            output_dir=save_path,
            num_train_epochs=12,
            per_device_train_batch_size=32,
            per_device_eval_batch_size=16,
            warmup_steps=500,
            weight_decay=0.01,
            evaluation_strategy="epoch",
            save_strategy="epoch",
            logging_dir=f'{save_path}/logs',
            logging_steps=10,
        )

        # 初始化Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            compute_metrics=self.compute_metrics,
            callbacks=[PlotLossesCallback()]
        )
        trainer.train()
        trainer.save_model(save_path)
        self.tokenizer.save_pretrained(save_path)

        return trainer.evaluate()  # 返回評估結果


In [18]:
class PlotLossesCallback(TrainerCallback):
    def __init__(self):
        self.train_losses = []
        self.eval_losses = []

    def on_epoch_end(self, args, state, control, **kwargs):
        if len(state.log_history) >= 2 and 'loss' in state.log_history[-2] and 'eval_loss' in state.log_history[-1]:
            self.train_losses.append(state.log_history[-2]['loss'])
            self.eval_losses.append(state.log_history[-1]['eval_loss'])

    def on_train_end(self, args, state, control, **kwargs):
        if self.train_losses and self.eval_losses:
            plt.figure(figsize=(10, 5))
            plt.plot(self.train_losses, label='Training Loss')
            plt.plot(self.eval_losses, label='Validation Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training and Validation Loss')
            plt.legend()
            plt.show()
        else:
            print("No loss data available to plot.")


In [19]:
def train_and_save_model(model_name, save_path):
    trainer = ModelTrainer(model_name, num_labels=len(np.unique(data_for_cls['label'])), id2label=id2label, label2id=label2id)
    train_dataset = trainer.prepare_dataset(train_df)
    val_dataset = trainer.prepare_dataset(val_df)
    evaluation_results = trainer.train(train_dataset, val_dataset, save_path)
    print(f"{model_name} evaluation results:", evaluation_results)

# BERT
train_and_save_model('bert-base-uncased', "./bert-base-uncased_fine-tuned_model")

# RoBERTa
train_and_save_model('roberta-base', "./roberta-base_fine-tuned_model")

# DistilBERT
train_and_save_model('distilbert-base-uncased', "./distilbert-base-uncased_fine-tuned_model")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9587,1.97468,0.078947,0.018603,0.078947,0.027412
2,1.9449,1.899379,0.105263,0.02381,0.105263,0.038429
3,1.9068,1.827911,0.184211,0.605817,0.184211,0.181546
4,1.826,1.749946,0.473684,0.466066,0.473684,0.452396
5,1.7815,1.657504,0.578947,0.432018,0.578947,0.494216
6,1.717,1.542298,0.631579,0.472039,0.631579,0.530493
7,1.6623,1.405809,0.631579,0.467654,0.631579,0.532999
8,1.6003,1.286508,0.657895,0.501839,0.657895,0.564052
9,1.4557,1.226967,0.657895,0.501839,0.657895,0.564052
10,1.303,1.147447,0.657895,0.510526,0.657895,0.572874


  accuracy_metric = load_metric("accuracy")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/1.65k [00:00<?, ?B/s]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/2.58k [00:00<?, ?B/s]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/2.52k [00:00<?, ?B/s]

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


Downloading builder script:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

  _warn_prf(average, modifier, msg_start, len(result))
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))
You can avoid this message in future by passing the argument `trust_remote_code=True`.

No loss data available to plot.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))


bert-base-uncased evaluation results: {'eval_loss': 1.031309962272644, 'eval_accuracy': 0.631578947368421, 'eval_precision': 0.5008166969147005, 'eval_recall': 0.631578947368421, 'eval_f1': 0.5551427588579292, 'eval_runtime': 3.3443, 'eval_samples_per_second': 11.363, 'eval_steps_per_second': 0.897, 'epoch': 12.0}


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9541,1.932779,0.026316,0.000693,0.026316,0.00135
2,1.9415,1.913352,0.026316,0.000693,0.026316,0.00135
3,1.9331,1.882958,0.526316,0.382234,0.526316,0.436636
4,1.9152,1.833657,0.578947,0.33518,0.578947,0.424561
5,1.8702,1.722364,0.578947,0.33518,0.578947,0.424561
6,1.8063,1.434897,0.578947,0.33518,0.578947,0.424561
7,1.7164,1.43268,0.578947,0.33518,0.578947,0.424561
8,1.6893,1.280346,0.578947,0.36391,0.578947,0.446907
9,1.5252,1.22517,0.605263,0.407018,0.605263,0.48655
10,1.2433,1.175841,0.578947,0.459552,0.578947,0.511207


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to 

No loss data available to plot.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))


roberta-base evaluation results: {'eval_loss': 1.2486028671264648, 'eval_accuracy': 0.631578947368421, 'eval_precision': 0.593440122044241, 'eval_recall': 0.631578947368421, 'eval_f1': 0.6090225563909776, 'eval_runtime': 3.4563, 'eval_samples_per_second': 10.994, 'eval_steps_per_second': 0.868, 'epoch': 12.0}


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.9592,1.924708,0.026316,0.001548,0.026316,0.002924
2,1.9368,1.893482,0.5,0.354839,0.5,0.415094
3,1.925,1.848003,0.578947,0.33518,0.578947,0.424561
4,1.9015,1.789372,0.578947,0.33518,0.578947,0.424561
5,1.8467,1.686535,0.578947,0.33518,0.578947,0.424561
6,1.8043,1.547465,0.578947,0.33518,0.578947,0.424561
7,1.7485,1.429508,0.578947,0.33518,0.578947,0.424561
8,1.699,1.353103,0.578947,0.374613,0.578947,0.454887
9,1.5777,1.29317,0.578947,0.407228,0.578947,0.477927
10,1.4212,1.233277,0.605263,0.44263,0.605263,0.509169


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to 

No loss data available to plot.


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


distilbert-base-uncased evaluation results: {'eval_loss': 1.087432622909546, 'eval_accuracy': 0.6052631578947368, 'eval_precision': 0.47368421052631576, 'eval_recall': 0.6052631578947368, 'eval_f1': 0.5286292654713707, 'eval_runtime': 3.4276, 'eval_samples_per_second': 11.086, 'eval_steps_per_second': 0.875, 'epoch': 12.0}


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.
  _warn_prf(average, modifier, msg_start, len(result))


In [20]:
# dataset_dict_for_cls['test'][0]

In [22]:
# 1. 文本分類 pipeline
def create_text_classification_pipeline(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return pipeline("text-classification", model=model, tokenizer=tokenizer, return_all_scores=False)

# 2. zero-shot cls pipeline
def create_zero_shot_pipeline(model_path):
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForSequenceClassification.from_pretrained(model_path)
    return pipeline("zero-shot-classification", model=model, tokenizer=tokenizer)

# 測試資料集
test_texts = test_data['clean_text']
test_labels = test_data['label']

# 提取 label
unique_labels = list(set(test_labels))
print(f"Unique Labels: {unique_labels}")

# BERT text classification
bert_text_classification_pipeline = create_text_classification_pipeline("./bert-base-uncased_fine-tuned_model")
bert_predictions = bert_text_classification_pipeline(test_texts)
print(bert_predictions)

# RoBERTa text classification
roberta_text_classification_pipeline = create_text_classification_pipeline("./roberta-base_fine-tuned_model")
roberta_predictions = roberta_text_classification_pipeline(test_texts)

# DistilBERT text classification
distilbert_text_classification_pipeline = create_text_classification_pipeline("./distilbert-base-uncased_fine-tuned_model")
distilbert_predictions = distilbert_text_classification_pipeline(test_texts)

# calculate accuracy for text classification
def calculate_accuracy(predictions, labels):
    predicted_labels = [pred['label'] for pred in predictions]
    accuracy = np.mean([pred_label == true_label for pred_label, true_label in zip(predicted_labels, labels)])
    return accuracy

bert_accuracy = calculate_accuracy(bert_predictions, test_labels)
roberta_accuracy = calculate_accuracy(roberta_predictions, test_labels)
distilbert_accuracy = calculate_accuracy(distilbert_predictions, test_labels)

print(f"BERT Accuracy: {bert_accuracy}")
print(f"RoBERTa Accuracy: {roberta_accuracy}")
print(f"DistilBERT Accuracy: {distilbert_accuracy}")

# RoBERTa Zero-shot
bert_zero_shot_pipeline = create_zero_shot_pipeline("./bert-base-uncased_fine-tuned_model")
bert_zero_shot_predictions = bert_zero_shot_pipeline(test_texts, candidate_labels=unique_labels)

# RoBERTa Zero-shot
roberta_zero_shot_pipeline = create_zero_shot_pipeline("./roberta-base_fine-tuned_model")
roberta_zero_shot_predictions = roberta_zero_shot_pipeline(test_texts, candidate_labels=unique_labels)

# DistilBERT Zero-shot
distilbert_zero_shot_pipeline = create_zero_shot_pipeline("./distilbert-base-uncased_fine-tuned_model")
distilbert_zero_shot_predictions = distilbert_zero_shot_pipeline(test_texts, candidate_labels=unique_labels)

# BART! 因為實驗發現 zero-shot 對一般語言模型很困難，因此我找了表現很不錯的預訓練模型直接使用作為 benchmark 當作比較基準。
bart_zero_shot_pipeline = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
bart_zero_shot_predictions = bart_zero_shot_pipeline(test_texts, candidate_labels=unique_labels)

# calculate accuracy for zero_shot
def calculate_zero_shot_accuracy(predictions, true_labels):
    correct_predictions = 0
    for pred, true_label in zip(predictions, true_labels):
        predicted_label = pred['labels'][0]
        if predicted_label == true_label:
            correct_predictions += 1
    accuracy = correct_predictions / len(true_labels)
    return accuracy

bert_zero_shot_accuracy = calculate_zero_shot_accuracy(bert_zero_shot_predictions, test_labels)
roberta_zero_shot_accuracy = calculate_zero_shot_accuracy(roberta_zero_shot_predictions, test_labels)
distilbert_zero_shot_accuracy = calculate_zero_shot_accuracy(distilbert_zero_shot_predictions, test_labels)
bart_zero_shot_accuracy = calculate_zero_shot_accuracy(bart_zero_shot_predictions, test_labels)

print(f"BERT Zero-shot Accuracy: {bert_zero_shot_accuracy}")
print(f"RoBERTa Zero-shot Accuracy: {roberta_zero_shot_accuracy}")
print(f"DistilBERT Zero-shot Accuracy: {distilbert_zero_shot_accuracy}")
print(f"BART Zero-shot Accuracy: {bart_zero_shot_accuracy}")

# print out text, predicted label, true label
def print_top_5_predictions(model_name, predictions, texts, true_labels):
    print(f"\nTop 5 Predictions for {model_name} Model:")
    for i in range(min(5, len(predictions))):  # 确保不超出样本数
        predicted_label = predictions[i]['label']  # 直接获取当前预测的标签
        print(f"Text: {texts[i]}")
        print(f"Predicted Label: {predicted_label}")
        print(f"True Label: {true_labels[i]}")
        print("-----")

print_top_5_predictions("BERT Text Classification", bert_predictions, test_texts, test_labels)
print_top_5_predictions("RoBERTa Text Classification", roberta_predictions, test_texts, test_labels)
print_top_5_predictions("DistilBERT Text Classification", distilbert_predictions, test_texts, test_labels)

# print out text, predicted label, true label
def print_top_5_zero_shot_predictions(model_name, predictions, texts, true_labels):
    print(f"\nTop 5 Zero-shot Predictions for {model_name} Model:")
    for i in range(5):
        predicted_label = predictions[i]['labels'][0]
        print(f"Text: {texts[i]}")
        print(f"Predicted Label: {predicted_label}")
        print(f"True Label: {true_labels[i]}")
        print("-----")

print_top_5_zero_shot_predictions("BERT Zero-shot Classification", bert_zero_shot_predictions, test_texts, test_labels)
print_top_5_zero_shot_predictions("RoBERTa Zero-shot Classification", roberta_zero_shot_predictions, test_texts, test_labels)
print_top_5_zero_shot_predictions("DistilBERT Zero-shot Classification", distilbert_zero_shot_predictions, test_texts, test_labels)
print_top_5_zero_shot_predictions("BART Zero-shot Classification", bart_zero_shot_predictions, test_texts, test_labels)


Unique Labels: ['identity', 'SOFTWARE', 'malware', 'vulnerability', 'threat-actor', 'location', 'attack-pattern']
[{'label': 'malware', 'score': 0.3512445092201233}, {'label': 'vulnerability', 'score': 0.26746803522109985}, {'label': 'location', 'score': 0.3851768672466278}, {'label': 'SOFTWARE', 'score': 0.25978484749794006}, {'label': 'malware', 'score': 0.2943859100341797}, {'label': 'threat-actor', 'score': 0.7279812693595886}, {'label': 'attack-pattern', 'score': 0.38648298382759094}, {'label': 'SOFTWARE', 'score': 0.2643638551235199}, {'label': 'attack-pattern', 'score': 0.3078076243400574}, {'label': 'SOFTWARE', 'score': 0.2263854295015335}, {'label': 'threat-actor', 'score': 0.6208673119544983}, {'label': 'location', 'score': 0.5478219389915466}, {'label': 'attack-pattern', 'score': 0.33683881163597107}, {'label': 'threat-actor', 'score': 0.6335844397544861}, {'label': 'threat-actor', 'score': 0.6437937021255493}, {'label': 'threat-actor', 'score': 0.6004820466041565}, {'label'

Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.


BERT Accuracy: 0.02631578947368421
RoBERTa Accuracy: 0.15789473684210525
DistilBERT Accuracy: 0.15789473684210525


Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.
Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to -1. Define a descriptive label2id mapping in the model config to ensure correct outputs.


BERT Zero-shot Accuracy: 0.13157894736842105
RoBERTa Zero-shot Accuracy: 0.10526315789473684
DistilBERT Zero-shot Accuracy: 0.13157894736842105
BART Zero-shot Accuracy: 0.21052631578947367

Top 5 Predictions for BERT Text Classification Model:
Text: carbanak also performs techniques for disabling security tools, deleting files that are left in malicious activity, and modifying registry to hide configuration information.
Predicted Label: malware
True Label: attack-pattern
-----
Text:  apart from argentinian ecommerce provider mercado libre / mercado pago, subsequent victimology has departed south america and pivoted to focus on the high-tech sector.  recent public victims have included:  it should be understood that in addition there are likely any number of other victims, targeted by attacks not known in the public sphere.
Predicted Label: vulnerability
True Label: location
-----
Text: , we were able to observe another github account with the name l4ckyguy, sharing the profile picture,