In [None]:
from __future__ import annotations

import hashlib
import importlib
import os
import random
import re
import string
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path
from typing import NamedTuple

In [None]:
import numpy as np
import ssdeep
import torch
import transformers
from charset_normalizer import detect as cdetect
from datasets import Dataset, DatasetDict, concatenate_datasets
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from torch.optim import AdamW
from tqdm import tqdm
from transformers import (AutoConfig, AutoModelForSequenceClassification,
                          AutoTokenizer, EarlyStoppingCallback,
                          TextClassificationPipeline, Trainer,
                          TrainingArguments)

In [None]:
# Fail fast: the rest of the notebook needs the PyTorch backend of transformers
# and a CUDA device (fp16 training and a pipeline pinned to device 0 are used later).
if not transformers.is_torch_available():
    raise RuntimeError("Torch is not available, make sure your Python env and dependencies are set")
if not torch.cuda.is_available():
    raise RuntimeError("Cuda is not available, please retry on a Cuda capable device")

In [None]:
# Print the name of every CUDA device visible to this process.
for device_index in range(torch.cuda.device_count()):
    device_properties = torch.cuda.get_device_properties(device_index)
    print(device_properties.name)

In [None]:
# `import importlib` alone does not guarantee that the `importlib.util` submodule
# is loaded, so import it explicitly before calling find_spec.
import importlib.util

# pyminusone is optional: when it is installed, scripts are deobfuscated during
# preprocessing; otherwise that step is skipped.
_MINUSONE_AVAILABLE = importlib.util.find_spec('pyminusone') is not None
if _MINUSONE_AVAILABLE:
    import pyminusone

## Required variables

In [None]:
# Datasets of legit/malicious pwsh scripts
# Both values are placeholders: point them at existing folders before running.
# With the empty default, os.path.exists("") is False and the assertions fail.
GOODWARES_DIR: str = ""
assert os.path.exists(GOODWARES_DIR), "Empty goodwares folder"
MALWARES_DIR: str = ""
assert os.path.exists(MALWARES_DIR), "Empty malwares folder"

In [None]:
# Folder containing model bin and tokenizer files
# Placeholder: must be set to an existing path, otherwise the assertion below fails.
MODEL_FOLDER = ""
assert os.path.exists(MODEL_FOLDER), "Empty model folder"

In [None]:
# Model training output folder
# Placeholder: must be a non-empty string; the run directory is created under it later.
OUT_FOLDER = ""
assert OUT_FOLDER

----

## Dataset

### Pre-processing

In [None]:
# Maximum number of characters kept from each normalized script.
_TRUNCATE_FILE_SIZE = 4096
# Matches single- or double-quoted strings (no capture group) or comments
# (group 1: `# ...` up to end of line, or `<# ... #>` blocks). Matching the quoted
# strings first prevents a `#` inside a string from being treated as a comment.
_CMT_REGEX = re.compile(r"\'[^\']*\'|\"[^\"]*\"|(#.*$|<#[\s\S]*?#>)", flags=re.IGNORECASE | re.MULTILINE)
# Default ssdeep comparison score above which two files are considered duplicates.
_SSDEEP_THRESHOLD = 10
# Seed used for the dataset shuffle.
_SEED = 42

In [None]:
class PreprocessedFile(NamedTuple):
    """One script after preprocessing, as built by ``preprocess_file``."""
    # SHA-256 hex digest of the raw file bytes
    sha256: str
    # ssdeep fuzzy hash of the raw file bytes
    ssdeep: str
    # Normalized (and, when pyminusone is available, deobfuscated) script text
    content: str
    # Encoding reported by read_and_decode for this file
    encoding: str

In [None]:
def read_and_decode(filepath: str) -> tuple[str|None, str|None]:
    """Read a file as bytes and decode it with its detected encoding.

    Args:
        filepath: Path of the file to read.

    Returns:
        ``(text, encoding)`` on success, where ``encoding`` is the codec that was
        actually used for decoding (``"utf-8"`` when detection found nothing).
        ``(None, None)`` if the file cannot be read or decoded; the error is printed.
    """
    try:
        with open(filepath, "rb") as file_obj:
            datab = file_obj.read()
        enc = cdetect(datab)["encoding"]
        if enc is None:
            # Fall back on utf8 if detection failed, and report that codec to the
            # caller instead of None so the returned encoding matches the text.
            enc = "utf-8"
        file_content = datab.decode(encoding=enc)
    except (OSError, UnicodeDecodeError) as ex:
        print(f"Unable to decode file {filepath}, error: {ex}")
        return None, None
    return file_content, enc

def _replace_callback(m: re.Match) -> str:
    """Substitution callback for ``_CMT_REGEX``: drop comments, keep everything else.

    Group 1 is only populated when the match is a comment; quoted strings match
    without it and are returned unchanged.
    """
    return "" if m.group(1) else m.group(0)

def normalize_text(content: str) -> str | None:
    """Strip comments and leading whitespace, then truncate the script.

    Args:
        content: Decoded script text.

    Returns:
        The normalized text cut to ``_TRUNCATE_FILE_SIZE`` characters (possibly
        empty), or ``None`` if normalization raised; the error is printed.
    """
    try:
        content = _CMT_REGEX.sub(_replace_callback, content)
        # Single-pass strip of the same character set; slicing one character at a
        # time is quadratic on files that start with a long run of whitespace.
        content = content.lstrip(string.whitespace)
    except Exception as ex:
        print(f"Unable to normalize file, error: {ex}")
        return None
    return content[:_TRUNCATE_FILE_SIZE]

def hash_file(filepath: str, enc: str) -> tuple[str|None, str|None]:
    """Compute the SHA-256 and ssdeep hashes of a file's raw bytes.

    Args:
        filepath: Path of the file to hash.
        enc: Encoding forwarded to ``ssdeep.hash``.

    Returns:
        ``(sha256_hex, ssdeep_hash)``, or ``(None, None)`` if the file cannot be
        read or hashed; the error is printed.
    """
    try:
        # Reading happens inside the try so an unreadable file is reported and
        # skipped like any other hashing failure instead of raising to the caller.
        with open(filepath, 'rb') as f:
            datab = f.read()
        sha256_hash = hashlib.sha256(datab).hexdigest()
        ssdeep_hash = ssdeep.hash(datab, enc)
    except Exception as ex:
        print(f"Couldn't hash {filepath}: {ex}")
        return None, None
    return sha256_hash, ssdeep_hash

def deobfuscate_with_minusone(content: str) -> str|None:
    """Deobfuscate PowerShell text with pyminusone.

    Args:
        content: Normalized script text.

    Returns:
        The deobfuscated text, or ``None`` if pyminusone failed; the error is printed.

    Raises:
        KeyboardInterrupt, SystemExit: Propagated so the run can still be interrupted.
    """
    try:
        return pyminusone.deobfuscate_powershell(content)
    except (KeyboardInterrupt, SystemExit):
        # Never swallow interpreter-level stop requests.
        raise
    except BaseException as ex:
        # BaseException is kept on purpose: presumably the native extension can raise
        # errors that do not derive from Exception — TODO confirm.
        print(f"Unable to deobfuscate: {ex}")
        return None

def preprocess_file(filepath: str) -> PreprocessedFile | None:
    """Run the full preprocessing chain on one file.

    Steps, in order: decode, normalize, optionally deobfuscate (only when
    pyminusone is available), then hash the raw file.

    Returns:
        A ``PreprocessedFile``, or ``None`` as soon as any step fails or the
        normalized text is empty.
    """
    decoded, encoding = read_and_decode(filepath)
    if decoded is None:
        return None

    text = normalize_text(decoded)
    if not text:
        # Covers both a normalization failure (None) and an empty result.
        return None

    if _MINUSONE_AVAILABLE:
        text = deobfuscate_with_minusone(text)
        if text is None:
            return None

    digest, fuzzy = hash_file(filepath, encoding)
    if digest is None or fuzzy is None:
        return None

    return PreprocessedFile(sha256=digest, ssdeep=fuzzy, content=text, encoding=encoding)

def preprocess_folder(folder: str | os.PathLike) -> list[PreprocessedFile]:
    """Preprocess every file under ``folder`` (recursively) in a process pool.

    Args:
        folder: Root directory to walk.

    Returns:
        The successfully preprocessed files, in sorted path order. Files for which
        ``preprocess_file`` returned ``None`` are skipped.
    """
    file_paths = sorted(os.path.join(root, file) for root, _, files in os.walk(folder) for file in files)
    # Use half the cores, but never fewer than one worker: os.cpu_count() may
    # return None, and halving gives 0 on a single-core host, which
    # ProcessPoolExecutor rejects.
    worker_count = max(1, (os.cpu_count() or 1) // 2)
    file_array = []
    with ProcessPoolExecutor(max_workers=worker_count) as executor:
        with tqdm(total=len(file_paths), desc="Processing files", unit="file") as pbar:
            for r in executor.map(preprocess_file, file_paths):
                if r:
                    file_array.append(r)
                pbar.update(1)
    return file_array

In [None]:
# Preprocess every legitimate script (label 0 later on).
gw = preprocess_folder(GOODWARES_DIR)

In [None]:
# Preprocess every malicious script (label 1 later on).
mw = preprocess_folder(MALWARES_DIR)

In [None]:
# Number of files that survived preprocessing in each class.
print(f"Got {len(gw)} goodwares, {len(mw)} malwares")

In [None]:
# Stop here if either class is empty: both are needed to train a two-class model.
if not gw or not mw:
    raise RuntimeError("Left with no files after preprocessing!")

### Deduplication

In [None]:
# Greedy near-duplicate removal based on ssdeep similarity
def deduplicate_files(preprocessed_files: list[PreprocessedFile], threshold: int = _SSDEEP_THRESHOLD):
    """Keep only files that are not too similar to an already kept file.

    Files are visited in input order; a file is kept when its ssdeep comparison
    score against every file kept so far is at most ``threshold``.

    Args:
        preprocessed_files: Candidates to filter.
        threshold: Score strictly above which two files count as duplicates.

    Returns:
        A ``set`` of the kept ``PreprocessedFile`` items (iteration order is
        therefore not defined).
    """
    kept = set()
    with tqdm(total=len(preprocessed_files), desc="Deduplicating files", unit="file") as progress:
        for candidate in preprocessed_files:
            too_similar = any(
                ssdeep.compare(candidate.ssdeep, other.ssdeep) > threshold
                for other in kept
            )
            if not too_similar:
                kept.add(candidate)
            progress.update(1)
    return kept

In [None]:
# Deduplicate within the goodware class only.
gw_dedup = deduplicate_files(gw)

In [None]:
# Deduplicate within the malware class only.
mw_dedup = deduplicate_files(mw)

In [None]:
# Class sizes after near-duplicate removal.
print(f"After deduplication: {len(gw_dedup)} goodwares, {len(mw_dedup)} malwares, {len(gw_dedup)+len(mw_dedup)} total")

In [None]:
# Stop here if deduplication emptied either class.
if not gw_dedup or not mw_dedup:
    raise RuntimeError("No files left after deduplication!")

----

### Splits

In [None]:
# Build the labelled records: 0 = goodware, 1 = malware.
# The deduplicated collections are sets, and set iteration order over strings is
# not stable across interpreter runs. Sorting by SHA-256 gives a fixed starting
# order so that the seeded shuffle applied afterwards is reproducible.
dataset_array = []
dataset_array += [
    {"label":0, "text":el.content, "encoding":el.encoding, "hash":el.sha256}
    for el in sorted(gw_dedup, key=lambda f: f.sha256)
]
dataset_array += [
    {"label":1, "text":el.content, "encoding":el.encoding, "hash":el.sha256}
    for el in sorted(mw_dedup, key=lambda f: f.sha256)
]

In [None]:
# Shuffle dataset
# A dedicated Random instance seeded with _SEED is used, so the global random
# state is left untouched. The list is shuffled in place.
rd = random.Random(x=_SEED)
rd.shuffle(dataset_array)

In [None]:
# Split into train, test, and validation datasets
# Membership is decided by the first hex digit of the SHA-256:
# 0-9/a/b -> train, c/d -> test, e/f -> validation.
train_dataset = Dataset.from_list(
    [record for record in dataset_array if record["hash"][0] in "0123456789ab"]
)
test_dataset = Dataset.from_list(
    [record for record in dataset_array if record["hash"][0] in "cd"]
)
val_dataset = Dataset.from_list(
    [record for record in dataset_array if record["hash"][0] in "ef"]
)
dataset = DatasetDict(
    {
        "train": train_dataset,
        "validation": val_dataset,
        "test": test_dataset,
    }
)

In [None]:
# Sizes of the train, validation and test splits (in that order).
len(train_dataset), len(val_dataset), len(test_dataset)

### Pre-tokenization

In [None]:
def prepare_tokenizer(tokenizer, max_length: int = 1024, pad_token: str = "<pad>"):
    """Set the padding token and maximum sequence length on a tokenizer, in place.

    Args:
        tokenizer: Tokenizer exposing ``add_special_tokens`` and ``model_max_length``.
        max_length: Value assigned to ``tokenizer.model_max_length``.
        pad_token: Token registered as the tokenizer's ``pad_token``.

    Returns:
        The same tokenizer object, for chaining.

    NOTE(review): if ``pad_token`` is not already in the vocabulary, registering it
    presumably grows the vocabulary, in which case the model's embedding matrix
    would need resizing — confirm against the checkpoint in use.
    """
    tokenizer.add_special_tokens({"pad_token": pad_token})
    tokenizer.model_max_length = max_length
    return tokenizer

In [None]:
# Load the tokenizer shipped with the model, then set its pad token and max length.
tokenizer = AutoTokenizer.from_pretrained(MODEL_FOLDER)
tokenizer = prepare_tokenizer(tokenizer)

In [None]:
def tokenize(el):
    """Tokenize the ``text`` field of one dataset record, truncated to 1024 tokens.

    Uses the module-level ``tokenizer``; no padding is requested here.
    """
    return tokenizer(el["text"], max_length=1024, truncation=True)

In [None]:
# Tokenize every split with half the available cores, but always at least one
# process: os.cpu_count() may return None, and halving gives 0 on a single-core
# host, which is not a valid num_proc.
tokenized_datasets = dataset.map(tokenize, batched=False, num_proc=max(1, (os.cpu_count() or 1) // 2))

In [None]:
# Display the tokenized splits.
tokenized_datasets

## Model

### Init Trainer and model

In [None]:
# Run-level settings; the same batch size is used for training and evaluation below.
run_name = "demo"
train_batch_size = 32
num_train_epochs = 4
# Each run writes to its own sub-folder of OUT_FOLDER.
output_dir = os.path.join(OUT_FOLDER, run_name)
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Evaluation and logging run every 10 steps, checkpoints are written every 100 steps,
# and load_best_model_at_end asks the Trainer to restore the best checkpoint once
# training stops. Training uses fp16 with gradient checkpointing and no
# gradient accumulation.
training_args = TrainingArguments(
    output_dir=output_dir,
    evaluation_strategy="steps",
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=train_batch_size,
    per_device_eval_batch_size=train_batch_size,
    eval_steps=10,
    save_steps=100,
    logging_steps=10,
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    fp16=True,
    bf16=False,
    run_name=run_name,
    disable_tqdm=False
)

In [None]:
# Start from the checkpoint's own configuration and override the dropout rates
# and the classification head size (2 labels: goodware / malware).
# NOTE(review): these attribute names look like BERT-style config fields — confirm
# they are the ones the checkpoint's config class actually reads.
configuration = AutoConfig.from_pretrained(MODEL_FOLDER)
configuration.hidden_dropout_prob = 0.
configuration.attention_probs_dropout_prob = 0.2
configuration.classifier_dropout = 0.2
configuration.num_labels = 2
configuration.output_hidden_states = False

In [None]:
# Load the pretrained weights into a sequence-classification model using the
# `configuration` object built from the same folder.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_FOLDER,
    config=configuration,
)

In [None]:
# Make sure every parameter whose name starts with "bert" is trainable,
# i.e. the encoder is fine-tuned rather than frozen.
for name, param in model.named_parameters():
    if name.startswith("bert"):
        param.requires_grad = True

### Init optimizer/scheduler

In [None]:
# Two parameter groups with different learning rates:
#   - the pretrained encoder (model.bert) at 2e-5,
#   - everything else (the classification head) at 1e-4.
pretrained_parms = list(model.bert.parameters())
# A set gives constant-time membership tests when selecting the remaining parameters.
pretrained_names = {f'bert.{k}' for k, _ in model.bert.named_parameters()}
classifier_parms = [v for k, v in model.named_parameters() if k not in pretrained_names]

optimizer = AdamW(
    [
        {
            'params': pretrained_parms,
            'lr': 2e-5,
        },
        {
            'params': classifier_parms,
            'lr': 1e-4,
        }
    ],
)

In [None]:
# Cosine schedule with 100 warmup steps.
# The total step count must be an integer number of optimizer updates. Count batches
# per epoch rounded up (a trailing partial batch still produces a step), using the
# effective batch size from the training arguments, then account for gradient
# accumulation. NOTE(review): this mirrors how Trainer is expected to count update
# steps with default dataloader settings — confirm for the installed version.
batches_per_epoch = -(-len(train_dataset) // training_args.train_batch_size)
updates_per_epoch = max(1, batches_per_epoch // training_args.gradient_accumulation_steps)
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=100,
    num_training_steps=updates_per_epoch * num_train_epochs,
)

### Init metrics

In [None]:
class MyTrainer(Trainer):
    """Trainer that adds the current learning rate to every log entry."""

    def log(self, logs: dict[str, float], *args, **kwargs) -> None:
        """Add ``LR*1e6`` (learning rate scaled by 1e6) to ``logs``, then log as usual.

        Extra positional/keyword arguments are forwarded untouched so the override
        keeps working if the base ``Trainer.log`` is called with more than ``logs``.
        """
        logs["LR*1e6"] = self._get_learning_rate() * 1e6
        super().log(logs, *args, **kwargs)


def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for an evaluation pass.

    Args:
        pred: Object exposing ``label_ids`` and ``predictions``; the predicted
            class is the argmax over the last axis of ``predictions``.

    Returns:
        Dict with keys ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)
    scores = precision_recall_fscore_support(y_true, y_pred, average='weighted', zero_division=0.0)
    accuracy = accuracy_score(y_true, y_pred)
    return dict(
        accuracy=accuracy,
        f1=scores[2],
        precision=scores[0],
        recall=scores[1],
    )


### Train

In [None]:
# Train on the train split and evaluate on the validation split. The optimizer and
# scheduler built above are passed in explicitly, and early stopping is configured
# with a patience of 30 and an improvement threshold of 1e-3.
trainer = MyTrainer(
    model=model,
    args=training_args,
    optimizers=(optimizer, lr_scheduler),
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[
        EarlyStoppingCallback(
            early_stopping_patience=30, early_stopping_threshold=1e-3
        )
    ],
)

In [None]:
# Run fine-tuning; checkpoints are written under the run's output directory.
print("Training...")
trainer.train()

In [None]:
# Save the model weights and config under <output_dir>/final_checkpoint.
# NOTE(review): load_best_model_at_end=True is set in the training arguments, so
# `model` presumably holds the best checkpoint here rather than the last one — confirm.
print("Saving last checkpoint of the model")
model.save_pretrained(os.path.join(trainer.args.output_dir, "final_checkpoint"))

### Eval

In [None]:
# Metrics from compute_metrics on the validation split.
print("Evaluating on valid set...")
trainer.evaluate(tokenized_datasets["validation"])

In [None]:
# Metrics from compute_metrics on the held-out test split.
print("Evaluating on test set...")
trainer.evaluate(tokenized_datasets["test"])

## False positives / false negatives

In [None]:
# Inference pipeline over raw text, truncating inputs to 1024 tokens like the
# training tokenization. device=0 selects the first CUDA device.
# With top_k=None each prediction is a list of per-label dicts (see how `preds`
# is consumed below).
pipe = TextClassificationPipeline(
    model=model, tokenizer=tokenizer, top_k=None, max_length=1024, truncation=True, device=0
)

In [None]:
# Concatenate all splits (test, train, validation order) to score every sample once.
full_dataset = concatenate_datasets(
    [tokenized_datasets["test"], tokenized_datasets["train"], tokenized_datasets["validation"]]
)

In [None]:
# Score the raw text of every sample.
preds = pipe(full_dataset["text"])

In [None]:
# Boolean masks that recover each sample's split from the first hex digit of its hash
# (0-9/a/b -> train, c/d -> test, e/f -> validation).
# Only the "hash" column is read, once, instead of iterating whole rows, which
# would materialize every column of every sample.
hash_prefixes = [h[0] for h in full_dataset["hash"]]
train_mask = np.array([prefix in "0123456789ab" for prefix in hash_prefixes])
test_mask = np.array([prefix in "cd" for prefix in hash_prefixes])
valid_mask = np.array([prefix in "ef" for prefix in hash_prefixes])

In [None]:
# Keep, for each sample, the score attached to "LABEL_1".
# NOTE(review): this presumably corresponds to class index 1 (malware) under the
# default label naming — confirm against the model config's id2label.
predictions = np.array([el["score"] for p in preds for el in p if el["label"] == "LABEL_1"])
# Boolean ground truth: True where the dataset label is 1.
ground_truth = np.array(full_dataset["label"], dtype=np.bool_)

In [None]:
def fnr_from_fpr(predictions, ground_truth, fpr_threshold):
    """Find the score threshold whose false positive rate is closest to a target.

    Samples are ranked by decreasing score. Flagging the top ``k`` samples gives a
    false positive rate (negatives among them / all negatives) and a false negative
    rate (positives not among them / all positives). The cut whose FPR is closest to
    ``fpr_threshold`` is selected; when several cuts share that FPR, the deepest one
    is used, since it flags more positives at no extra false positive cost.

    Args:
        predictions: 1-D array of scores (higher means more likely positive).
        ground_truth: 1-D array of true labels, same length (truthy = positive).
        fpr_threshold: Target false positive rate, as a fraction.

    Returns:
        ``(fnr, threshold)``. ``threshold`` is the score of the last flagged sample,
        so the rates hold for ``score >= threshold``. ``fnr`` is NaN when there are
        no positive samples.

    Raises:
        ValueError: If ``predictions`` is empty.
    """
    predictions = np.asarray(predictions)
    ground_truth = np.asarray(ground_truth, dtype=bool)
    if predictions.size == 0:
        raise ValueError("fnr_from_fpr requires at least one prediction")

    # Rank samples by decreasing score
    order = np.argsort(predictions)[::-1]
    sorted_predictions = predictions[order]
    sorted_ground_truth = ground_truth[order]

    # Positives / negatives among the top-k samples, for every k
    cum_true_positives = np.cumsum(sorted_ground_truth)
    cum_false_positives = np.cumsum(~sorted_ground_truth)

    total_positives = int(cum_true_positives[-1])
    total_negatives = int(cum_false_positives[-1])

    # Without negatives no cut can produce a false positive: treat FPR as 0 everywhere
    if total_negatives > 0:
        fpr = cum_false_positives / total_negatives
    else:
        fpr = np.zeros(len(sorted_predictions))

    # Closest FPR to the target; search from the end so ties resolve to the last
    # index of a plateau rather than the first.
    distance = np.abs(fpr - fpr_threshold)
    fpr_index = len(distance) - 1 - int(np.argmin(distance[::-1]))

    if total_positives > 0:
        fnr = 1.0 - cum_true_positives[fpr_index] / total_positives
    else:
        fnr = float("nan")

    return fnr, sorted_predictions[fpr_index]

In [None]:
# FNR and score threshold on the train split for several target false positive rates.
print("Train thresholds")
for fpr_rate in [0.001, 0.005, 0.01, 0.02, 0.05]:
    fnr, threshold = fnr_from_fpr(predictions[train_mask], ground_truth[train_mask], fpr_rate)
    print(f"False Negative Rate: at {fpr_rate*100:1.2f}% FP : {fnr*100:5.2f}% , threshold={threshold:1.6f}")

----

In [None]:
# FNR and score threshold on the validation split for several target false positive rates.
print("Valid thresholds")
for fpr_rate in [0.001, 0.005, 0.01, 0.02, 0.05]:
    fnr, threshold = fnr_from_fpr(predictions[valid_mask], ground_truth[valid_mask], fpr_rate)
    print(f"False Negative Rate: at {fpr_rate*100:1.2f}% FP : {fnr*100:5.2f}% , threshold={threshold:1.6f}")

----

In [None]:
# FNR and score threshold on the test split for several target false positive rates.
print("Test thresholds")
for fpr_rate in [0.001, 0.005, 0.01, 0.02, 0.05]:
    fnr, threshold = fnr_from_fpr(predictions[test_mask], ground_truth[test_mask], fpr_rate)
    print(f"False Negative Rate: at {fpr_rate*100:1.2f}% FP : {fnr*100:5.2f}% , threshold={threshold:1.6f}")

----

In [None]:
# FNR and score threshold over all splits together for several target false positive rates.
print("Full dataset thresholds")
for fpr_rate in [0.001, 0.005, 0.01, 0.02, 0.05]:
    fnr, threshold = fnr_from_fpr(predictions, ground_truth, fpr_rate)
    print(f"False Negative Rate: at {fpr_rate*100:1.2f}% FP : {fnr*100:5.2f}% , threshold={threshold:1.6f}")

----

In [None]:
# Operating point used to list the misclassified samples.
fpr_threshold = 0.005 # 0.5% FP
fnr, thresh = fnr_from_fpr(predictions, ground_truth, fpr_threshold)
# fnr_from_fpr counts the sample sitting exactly at the returned threshold as
# flagged (its cumulative sums include that index), so the comparison is inclusive.
predictions_bool = predictions >= thresh

In [None]:
# Collect (hash, score) pairs for every misclassified sample at the chosen threshold.
# The hash column is read once up front rather than on every loop iteration.
hashes = full_dataset["hash"]
FP = []
FN = []
for idx in np.flatnonzero(predictions_bool != ground_truth):
    entry = (hashes[idx], predictions[idx])
    if ground_truth[idx]:
        # Malware scored below the threshold
        FN.append(entry)
    else:
        # Goodware scored at or above the threshold
        FP.append(entry)

In [None]:
# List false positives, lowest score first.
for h, s in sorted(FP, key=lambda x:x[1]):
    print("FP: %s : %2.4f" % (h, s))

In [None]:
# List false negatives, lowest score first.
for h, s in sorted(FN, key=lambda entry: entry[1]):
    print("FN: %s : %2.4f" % (h, s))

----

In [None]:
# Drop the notebook's reference to the model and release cached CUDA memory.
# NOTE(review): `trainer`, `pipe` and `optimizer` were all built from this model and
# still exist at this point, so the GPU memory is presumably not fully released
# until those names are deleted as well — confirm.
del model
torch.cuda.empty_cache()