## Smishing detection using [danielhenel/smishing-detection-mistral-7b-instruct-v0.3](https://huggingface.co/danielhenel/smishing-detection-mistral-7b-instruct-v0.3) - evaluation of the model

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import pickle
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the pickled test split and report how many messages carry each label.
# NOTE(review): pickle.load executes arbitrary code from the file - only use a trusted test_data.pkl.
with open("./data/test_data.pkl", "rb") as input_file:
    test_data = pickle.load(input_file)

# Messages and their ground-truth labels, as stored under these two keys.
X_test, y_test = test_data["X_test"], test_data["y_test"]

# Count each class separately; labels other than "ham"/"smish" are counted in neither total.
total_hams_count = sum(1 for label in y_test if label == "ham")
total_smishes_count = sum(1 for label in y_test if label == "smish")

print("There is {} hams and {} smishes in the test dataset.".format(total_hams_count, total_smishes_count))

There is 954 hams and 161 smishes in the test dataset.


In [3]:
# Hugging Face identifiers: the base checkpoint and the adapter loaded on top of it.
base_model = 'mistralai/Mistral-7B-Instruct-v0.3'
adapter_model = 'danielhenel/smishing-detection-mistral-7b-instruct-v0.3'

# 4-bit NF4 quantization with bfloat16 compute; double quantization is disabled.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=False,
)

# Load the quantized base model; device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

# NOTE(review): these two settings look carried over from a fine-tuning setup -
# confirm they are wanted for evaluation.
model.config.use_cache = False
model.config.pretraining_tp = 1

# Attach the adapter weights to the base model.
model.load_adapter(adapter_model)

# The tokenizer is loaded from the adapter repository, not from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(adapter_model, trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 3/3 [00:24<00:00,  8.10s/it]


In [4]:
# Evaluate the model on every test message and tally the outcomes.
# "smish" is treated as the positive class, so a "false ham" is a false negative
# and a "false smish" is a false positive.
false_hams_indicies = []
false_smishes_indicies = []
false_hams_count = 0
false_smishes_count = 0
true_hams_count = 0
true_smishes_count = 0
# Answers that are neither "ham" nor "smish" are recorded here and counted in no other tally.
errors_count = 0
errors_indicies = []
errors = []

# The pipeline and the prompt template do not depend on the message, so they are
# built once instead of on every iteration.
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=500)

prompt_template = (
    "<s>[INST] Do you think it is a ham or smish message? "
    "Your output should be a single word 'ham' or 'smish'. "
    "Do not write a sentence. "
    "Output is case-sensitive. "
    "SMS content: {}[/INST]"
)

for i in range(len(X_test)):
    prompt = prompt_template.format(X_test[i])

    answer = pipe(prompt)
    # Keep the text after the first "[/INST]" and before any "</s>", lower-cased and stripped.
    # NOTE(review): a message that itself contains "[/INST]" would make index 1 point at
    # part of the prompt instead of the model's answer - confirm this cannot occur in the data.
    answer = answer[0]['generated_text'].split("[/INST]")[1].split("</s>")[0].lower().strip()

    if answer not in ["ham", "smish"]:
        # Unparseable answer: store it for later inspection.
        errors_count += 1
        errors_indicies.append(i)
        errors.append(answer)
    elif answer == "ham" and y_test[i] == "ham": # correctly recognized as a ham
        true_hams_count += 1
    elif answer == "smish" and y_test[i] == "smish": # correctly recognized as a smish
        true_smishes_count += 1
    elif answer == "ham" and y_test[i] == "smish": # wrongly recognized as a ham
        false_hams_indicies.append(i)
        false_hams_count += 1
    elif answer == "smish" and y_test[i] == "ham": # wrongly recognized as a smish
        false_smishes_indicies.append(i)
        false_smishes_count += 1

# errors warning
if errors_count != 0:
    if errors_count == 1:
        print("WARNING: {} error".format(errors_count))
    else:
        print("WARNING: {} errors".format(errors_count))

# save results for further analysis
results = {"FN" : false_hams_count, "FP" : false_smishes_count,
           "TN" : true_hams_count, "TP" : true_smishes_count,
           "FN_indicies" : false_hams_indicies, "FP_indicies" : false_smishes_indicies,
           "errors_count" : errors_count, "errors" : errors, "errors_indicies" : errors_indicies}

with open("./results/results_mistral_7b_v0.3_fine_tuned.pkl", 'wb') as handle:
    pickle.dump(results, handle, protocol=pickle.HIGHEST_PROTOCOL)

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
  attn_output = torch.nn.functional.scaled_dot_product_attention(




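### The performance of our fine-tuned Mistral 7B v0.3 model in smishing detection

Note: the figures below are computed only over messages for which the model answered exactly `ham` or `smish`. Messages with any other answer (counted in `errors_count`) are excluded from every metric.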

In [5]:
# Confusion-matrix counts, with "smish" as the positive class.
FN = false_hams_count     #FN - messages wrongly recognized as not smishes (hams)
FP = false_smishes_count  #FP - messages wrongly recognized as smishes
TN = true_hams_count      #TN - messages correctly recognized as not smishes (hams)
TP = true_smishes_count   #TP - messages correctly recognized as smishes

# TOTAL covers only messages whose answer was parsed as "ham" or "smish".
# Messages counted in errors_count are in none of the four tallies, so every
# metric derived from TOTAL ignores them.
TOTAL = FN + FP + TN + TP

# Make the excluded share explicit so the metrics are not read as covering the whole test set.
if errors_count:
    print("WARNING: {} of {} test messages had an unparseable answer and are excluded from the metrics ({:.2f}% of messages evaluated).".format(
        errors_count, TOTAL + errors_count, TOTAL / (TOTAL + errors_count) * 100))

In [6]:
# Share of each outcome among the classified messages (TOTAL), one line per outcome.
outcome_report = (
    ("Messages wrongly recognized as hams: {0:.2f}%", FN),
    ("Messages wrongly recognized as smishes: {0:.2f}%", FP),
    ("Messages correctly recognized as hams: {0:.2f}%", TN),
    ("Messages correctly recognized as smishes: {0:.2f}%", TP),
)
for outcome_template, outcome_count in outcome_report:
    print(outcome_template.format(outcome_count / TOTAL * 100))

Messages wrongly recognized as hams: 0.00%
Messages wrongly recognized as smishes: 1.04%
Messages correctly recognized as hams: 92.85%
Messages correctly recognized as smishes: 6.11%


#### Accuracy

In [7]:
# Accuracy: correctly classified messages as a share of all classified messages.
correct_predictions = TP + TN
accuracy = correct_predictions / TOTAL
print("{0:.2f}%".format(accuracy * 100))

98.96%


#### Recall

In [8]:
# Recall: correctly recognized smishes as a share of all classified smishes.
actual_smishes = TP + FN
recall = TP / actual_smishes
print("{0:.2f}%".format(recall * 100))

100.00%


#### Precision

In [9]:
# Precision: correctly recognized smishes as a share of all messages predicted as smish.
predicted_smishes = TP + FP
precision = TP / predicted_smishes
print("{0:.2f}%".format(precision * 100))

85.45%


#### F1 score

In [10]:
# F1 score in the form TP / (TP + (FP + FN) / 2), equivalent to 2TP / (2TP + FP + FN).
mean_misclassified = (FP + FN) / 2
F1_score = TP / (TP + mean_misclassified)
print("{0:.2f}%".format(F1_score * 100))

92.16%
