# Train models for Suspicious

Use this Python notebook to train your models for the AIMailAnalyzer from a CSV dataset.

## Import libraries

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import gridspec
from matplotlib.colors import LinearSegmentedColormap
from collections import Counter
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.decomposition import PCA
from imblearn.over_sampling import RandomOverSampler, SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline

In [None]:
import hashlib
import re
from collections import defaultdict
import matplotlib.pyplot as plt

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from torch.utils.data import DataLoader, TensorDataset
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix

## Set parameters

Set the parameters for the training and data processing. You can change the parameters depending on your needs and the data you have.

#### Data

- **CSV_DATASET**: Path to the CSV file to load as the dataset. The file should contain a column named 'body' with the email content and a column named 'label' with the label for each email.
- **y_encoder**: A dictionary to encode the labels to integers. The labels need to match the labels in your dataset ('label' column of the CSV). For example:
  ```python
  y_encoder = {
    '0_LEGIT_INTERNAL_COMMUNICATION': 0,
    '0_LEGIT_EXTERNAL_COMMUNICATION': 1,
    '1_SPAM': 2,
    '1_NEWSLETTER': 3,
    '2_CLASSIC_PHISHING': 4,
    '2_WHALING_PHISHING': 5,
    '2_CLONE_PHISHING': 6,
    '2_BLACKMAILING_PHISHING': 7,
  }
  ```
- **LABELS**: List of labels to use for the classification. These should match the folder names in your Suspicious mailbox.
- **VECTORIZER**: The model to use for vectorizing the emails.

#### Data visualization

- **class_labels**: List of class labels for visualization.
- **class_colors**: List of colors for each class label.
- **safe_suspicious_labels**: List of labels for safe vs suspicious classification.
- **safe_suspicious_colors**: List of colors for safe vs suspicious classification.
- **unwanted_dangerous_labels**: List of labels for unwanted vs dangerous classification.
- **unwanted_dangerous_colors**: List of colors for unwanted vs dangerous classification.
- **safe_unwanted_dangerous_labels**: List of labels for safe vs unwanted vs dangerous classification.
- **safe_unwanted_dangerous_colors**: List of colors for safe vs unwanted vs dangerous classification.
- **safe_labels**: List of labels for safe classification.
- **safe_colors**: List of colors for safe classification.
- **unwanted_labels**: List of labels for unwanted classification.
- **unwanted_colors**: List of colors for unwanted classification.
- **dangerous_labels**: List of labels for dangerous classification.
- **dangerous_colors**: List of colors for dangerous classification.

Example:
```python
class_labels = ["internal", "external", "spam", "newsletter", "classic phishing", "whaling", "clone", "blackmail"]
class_colors = ["mediumseagreen", "limegreen", "orange", "goldenrod", "red", "firebrick", "indianred", "lightcoral"]
safe_suspicious_labels = ["safe", "suspicious"]
safe_suspicious_colors = ["green", "orange"]
unwanted_dangerous_labels = ["unwanted", "dangerous"]
unwanted_dangerous_colors = ["gold", "red"]
safe_unwanted_dangerous_labels = ["safe", "unwanted", "dangerous"]
safe_unwanted_dangerous_colors = ["green", "gold", "red"]
safe_labels = ["internal", "external"]
safe_colors = ["mediumseagreen", "limegreen"]
unwanted_labels = ["spam", "newsletter"]
unwanted_colors = ["gold", "khaki"]
dangerous_labels = ["classic phishing", "whaling", "clone", "blackmail"]
dangerous_colors = ["red", "firebrick", "indianred", "lightcoral"]
```

#### Training

- **TEST_SIZE**: The proportion of the dataset to include in the test split.
- **RANDOM_STATE**: Controls the shuffling applied to the data before applying the split, and seeds the resampling steps. Pass an int for reproducible results.
- **SMOTE_FACTOR**: The factor to increase the minority class by using SMOTE.
- **ROS_FACTOR**: The factor to increase the minority class by using Random Over Sampling.
- **CAP_MULT**: The factor to determine the maximum number of samples per class after capping.
- **OVERSAMPLED_CSV**: If you oversampled emails another way, you can use this CSV file to load the oversampled dataset. Set it to `None` if you don't want to use it.
- **BATCH_SIZE**: Number of samples per gradient update.
- **LEARNING_RATE**: The learning rate for the model.
- **EPOCHS**: Number of epochs to train the model.
- **MODEL_OUTPUT_PATH**: Folder to save the trained models.

In [None]:
# Data
CSV_DATASET = ''
y_encoder = {
    
}
LABELS = list(y_encoder.keys())
VECTORIZER = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

# Data visualization
class_labels = []
class_colors = []
safe_suspicious_labels = []
safe_suspicious_colors = []
unwanted_dangerous_labels = []
unwanted_dangerous_colors = []
safe_unwanted_dangerous_labels = []
safe_unwanted_dangerous_colors = []
safe_labels = []
safe_colors = []
unwanted_labels = []
unwanted_colors = []
dangerous_labels = []
dangerous_colors = []

# Training
TEST_SIZE =  # Recommended: 0.2 - 0.3
RANDOM_STATE = 42

SMOTE_FACTOR =  # Recommended: 1.0 - 1.5
ROS_FACTOR =  # Recommended: 1.0 - 1.25
CAP_MULT =  # Recommended: 1.5 - 2.0
OVERSAMPLED_CSV = None # Recommended: None or 'oversampled_data.csv'

BATCH_SIZE =  # Recommended: 8 - 16
LEARNING_RATE = # Recommended: 0.001 - 0.01
EPOCHS =  # Recommended: 10 - 20

# Note that the recommended values correspond to a small imbalanced dataset (~1000 samples). For larger datasets, you can increase the batch size and decrease the number of epochs.

MODEL_OUTPUT_PATH = 'models'

In [None]:
# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Utils

In [None]:
def get_header_dict_list(msg):
    """Group the (name, value) pairs of ``msg.items()`` by name.

    Returns a ``defaultdict(list)`` in which every name maps to the list of
    its values, in the order ``msg.items()`` yields them.
    """
    grouped = defaultdict(list)
    for name, content in msg.items():
        grouped[name] += [content]
    return grouped

def calculate_hash(data, hash_type="sha256"):
    """Return the hexadecimal digest of ``data`` (a bytes-like object).

    Only ``hash_type="sha256"`` is supported; any other value raises
    ``ValueError``.
    """
    if hash_type != "sha256":
        raise ValueError(f"Unsupported hash type: {hash_type}")

    return hashlib.sha256(data).hexdigest()

In [None]:
# Link that appears in both the English and the French half of the warning
# banner that process_body() strips from email bodies.
WARNING_BANNER_URL = (
    "https://intranet.peopleonline.corp.thales/sites/group/"
    "transformation-and-development/thales-cybersecurity/"
    "suspicious-email-reporting"
)

# English half of the banner (plain ASCII).
WARNING_BANNER_EN = (
    "This email has been detected as potentially unwanted, make sure it is "
    "not a phishing attempt by visiting this page: Suspicious email "
    "reporting - Thales Cybersecurity (corp.thales) "
    "<" + WARNING_BANNER_URL + ">"
    "You will find there the procedure to follow if this message turns out "
    "to be a phishing attempt or spam. Otherwise, please ignore this message."
)

# French half of the banner. ``{apos}`` stands for the apostrophe character
# (typographic or ASCII depending on the variant) and ``{url}`` for the link.
WARNING_BANNER_FR_TEMPLATE = (
    "Ce mail a été détecté comme potentiellement indésirable, assurez-vous "
    "qu{apos}il ne s{apos}agit pas d{apos}une tentative d{apos}hameçonnage "
    "en consultant cette page : Signalement des e-mails suspects - Thales "
    "Cybersecurité (corp.thales) <{url}>"
    "Vous y trouverez la procédure à suivre si ce message s{apos}avère être "
    "une tentative d{apos}hameçonnage ou un spam. Si ce n{apos}est pas le "
    "cas, merci d{apos}ignorer ce message."
)


def _warning_banner_patterns():
    """Return one regex per known spelling of the warning banner.

    Three spellings are covered: typographic quotes, plain ASCII quotes, and
    the typographic text whose UTF-8 bytes were decoded as Mac Roman
    (mojibake). The banner text is passed through ``re.escape`` so that
    "(corp.thales)" matches literal parentheses; the previous hand-written
    patterns left them unescaped, which turned them into a capture group and
    made the patterns miss any banner that contains the parentheses. The
    space in "en consultant" accepts one or more spaces, and the surrounding
    rules accept any run of at least ten dashes.
    """
    typographic = (
        WARNING_BANNER_EN
        + "\u201d"
        + WARNING_BANNER_FR_TEMPLATE.format(apos="\u2019", url=WARNING_BANNER_URL)
    )
    plain = (
        WARNING_BANNER_EN
        + '"'
        + WARNING_BANNER_FR_TEMPLATE.format(apos="'", url=WARNING_BANNER_URL)
    )
    mojibake = typographic.encode("utf-8").decode("mac_roman")

    patterns = []
    for variant in (typographic, mojibake, plain):
        head, _, tail = variant.partition("en consultant")
        patterns.append(
            r"-{10,}"
            + re.escape(head)
            + "en +consultant"
            + re.escape(tail)
            + r"-{10,}\n"
        )
    return patterns


# Ordered (compiled pattern, replacement) pairs; the order matters because
# later rules rely on the output of earlier ones.
_BODY_CLEANUP_RULES = [
    (re.compile(rule_pattern, flags=re.MULTILINE), rule_replacement)
    for rule_pattern, rule_replacement in (
        [
            (r"\u00A0|\r", ""),  # non-breaking spaces and carriage returns
            (r" +\n", "\n"),  # spaces right before a line break
            (r"=\n", ""),  # "=" followed by a line break
            (r"\[cid:.*?\]\n?", ""),  # "[cid:...]" references
            (
                r".*THALES (GROUP|ALENIA SPACE) (LIMITED DISTRIBUTION|CONFIDENTIAL).*\n",
                "",
            ),
            (r"Sensitivity:.*\n", ""),
            (r"Critère de diffusion ?:.*\n", ""),
        ]
        + [(banner_pattern, "") for banner_pattern in _warning_banner_patterns()]
        + [
            (r"\n{3,}", "\n\n"),  # three or more line breaks become two
            (r"((From|De).*)\n\n", r"\1\n"),
            # NOTE(review): with re.MULTILINE this applies at every line
            # boundary, not only at the two ends of the text - confirm that
            # this is intended.
            (r"^\s+|\s+$", ""),
        ]
    )
]


def process_body(text: str) -> str:
    """Clean an email body before hashing and vectorization.

    Applies the substitutions of ``_BODY_CLEANUP_RULES`` in order and returns
    the resulting text.

    Args:
        text: Raw email body.

    Returns:
        The cleaned body.
    """
    for pattern, replacement in _BODY_CLEANUP_RULES:
        text = pattern.sub(replacement, text)
    return text

In [None]:
def plot_email_statistics_and_pca(X, y, label_names, colors, plot_title):
    """Show per-class counts next to a 2-D PCA projection of the vectors.

    Args:
        X: Sequence of feature vectors, one per email. Every row is flattened
            before PCA is fitted.
        y: Integer class index of each row of ``X``; index ``i`` is displayed
            with ``label_names[i]``.
        label_names: Display name of every class, ordered by class index.
        colors: One matplotlib colour per entry of ``label_names``.
        plot_title: Title placed above both sub-plots.
    """
    # Number of samples per class index, in the order of label_names
    class_counts = Counter(y)
    label_counts = [class_counts.get(label_idx, 0) for label_idx in range(len(label_names))]

    # One bar colour per label name
    color_map = {label_name: colors[i] for i, label_name in enumerate(label_names)}
    label_color_list = [color_map[label_name] for label_name in label_names]

    # Left: bar chart of counts, right: PCA scatter plot
    fig = plt.figure(figsize=(15, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3])
    ax0 = fig.add_subplot(gs[0])
    ax1 = fig.add_subplot(gs[1])

    bars = ax0.bar(label_names, label_counts, color=label_color_list)

    # Write each count on top of its bar
    for bar, count in zip(bars, label_counts):
        ax0.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            str(count),
            ha='center',
            va='bottom',
            fontsize=10,
            color='black',
        )

    ax0.set_ylabel('Count')
    ax0.set_title('Count of Email Types')
    # Keep the historical 0-1000 scale for small datasets, but grow the axis
    # when a class is larger so that its bar and count label are not clipped.
    ax0.set_ylim(0, max(1000, 1.1 * max(label_counts, default=0)))

    # Project the flattened vectors onto their first two principal components
    pca = PCA(n_components=2)
    X_flat = np.array([np.array(row).flatten() for row in X])
    X_train_pca = pca.fit_transform(X_flat)

    # Points are coloured by class index through a colormap built from colors
    scatter = ax1.scatter(
        X_train_pca[:, 0],
        X_train_pca[:, 1],
        c=y,
        cmap=LinearSegmentedColormap.from_list("custom_cmap", colors),
        alpha=0.6
    )

    fig.colorbar(scatter, ax=ax1, label='Email Types')
    ax1.set_title('PCA Projection of Resampled Training Data')
    ax1.set_xlabel('Principal Component 1')
    ax1.set_ylabel('Principal Component 2')

    plt.suptitle(plot_title)
    plt.tight_layout()
    plt.show()

In [None]:
def remove_similar_mails(df, field, threshold=0.9, debug=False):
    """Drop rows whose vector is too similar to the vector of an earlier row.

    Rows are visited in order. Each row that is still kept removes every later
    row whose cosine similarity with it is strictly above ``threshold``.

    Args:
        df: DataFrame holding one vector per row in column ``field``. With
            ``debug=True`` it must also have a 'body' column of strings.
        field: Name of the column that contains the vectors.
        threshold: Similarity above which the later row is removed.
        debug: When true, print the start of each kept body together with the
            bodies removed because of it.

    Returns:
        A new DataFrame with the remaining rows and a fresh 0..n-1 index.
    """
    n = len(df)
    if n == 0:
        # np.vstack raises on an empty sequence and there is nothing to compare
        return df.reset_index(drop=True)

    vectors = np.vstack(df[field].values)
    similarity_matrix = cosine_similarity(vectors)

    removed = np.zeros(n, dtype=bool)

    for i in tqdm(range(n), desc="Removing similar mails"):
        if removed[i]:
            continue

        # Later rows that are still kept and too close to row i
        too_close = (similarity_matrix[i, i + 1:] > threshold) & ~removed[i + 1:]
        similar_indices = np.flatnonzero(too_close) + i + 1
        removed[similar_indices] = True

        if debug and similar_indices.size:
            print(f"\n[DEBUG] Original: {df.iloc[i]['body'][:75]}...")
            for j in similar_indices:
                similarity = similarity_matrix[i, j]
                print(f"[DEBUG] Similar {similarity:.4f} (>{threshold}): {df.iloc[j]['body'][:75]}...")

    keep_indices = np.flatnonzero(~removed)
    df_cleaned = df.iloc[keep_indices].reset_index(drop=True)
    return df_cleaned

In [None]:
def _growth_targets(counts, factor, ceiling):
    """Return ``{class: target}`` for the classes below ``ceiling`` that ``factor`` grows.

    The target is ``int(count * factor)`` capped at ``ceiling``; classes whose
    target would not exceed their current count are left out.
    """
    targets = {}
    for cls, n in counts.items():
        if n < ceiling:
            target = min(int(n * factor), ceiling)
            if target > n:
                targets[cls] = target
    return targets


def oversample_data(X_train, y_train, smote_factor=SMOTE_FACTOR, ros_factor=ROS_FACTOR, random_state=RANDOM_STATE, cap_mult=CAP_MULT):
    """Rebalance a training set with ROS, then SMOTE, then random undersampling.

    Classes smaller than the mean class size are first grown by ``ros_factor``
    with random oversampling, then by ``smote_factor`` with SMOTE, never past
    the mean. Classes larger than ``cap_mult`` times the mean are finally cut
    down to that cap. Each sampler is applied exactly once, in sequence.

    Args:
        X_train: Feature matrix.
        y_train: Class label of every row of ``X_train``.
        smote_factor: Growth factor applied by SMOTE.
        ros_factor: Growth factor applied by random oversampling.
        random_state: Seed handed to every sampler.
        cap_mult: Multiple of the mean class size used as undersampling cap.

    Returns:
        ``(X_resampled, y_resampled)``. The inputs are returned unchanged, and
        nothing is printed, when no resampling step applies.
    """
    original_counts = Counter(y_train)
    mean_n = int(pd.Series(list(original_counts.values())).mean())
    undersample_cap = int(cap_mult * mean_n)

    X_res, y_res = X_train, y_train
    stage_reports = []  # (label, class counts) after every executed step

    # -------- ROS (only for classes below mean) --------
    ros_strategy = _growth_targets(original_counts, ros_factor, mean_n)
    if ros_strategy:
        ros = RandomOverSampler(sampling_strategy=ros_strategy, random_state=random_state)
        X_res, y_res = ros.fit_resample(X_res, y_res)
        stage_reports.append(("After ROS :", Counter(y_res)))

    # -------- SMOTE (after ROS, still capped at mean) --------
    after_ros = Counter(y_res)
    smote_strategy = _growth_targets(after_ros, smote_factor, mean_n)
    # SMOTE interpolates between a sample and its neighbours inside the same
    # class, so a class with a single sample cannot be synthesized from.
    for cls in [c for c in smote_strategy if after_ros[c] < 2]:
        print(f"Skipping SMOTE for class {cls}: only {after_ros[cls]} sample available")
        del smote_strategy[cls]

    if smote_strategy:
        # k_neighbors must be smaller than the CURRENT size of the smallest
        # class being resampled (not its target size)
        smallest_class_n = min(after_ros[cls] for cls in smote_strategy)
        k = min(5, smallest_class_n - 1)
        smote = SMOTE(sampling_strategy=smote_strategy, random_state=random_state, k_neighbors=k)
        X_res, y_res = smote.fit_resample(X_res, y_res)
        stage_reports.append(("After SMOTE :", Counter(y_res)))

    # -------- Undersampling (cap = cap_mult x mean) --------
    after_smote = Counter(y_res)
    rus_strategy = {cls: undersample_cap for cls, n in after_smote.items() if n > undersample_cap}
    if rus_strategy:
        rus = RandomUnderSampler(sampling_strategy=rus_strategy, random_state=random_state)
        X_res, y_res = rus.fit_resample(X_res, y_res)
        stage_reports.append(("After RUS :", Counter(y_res)))

    if not stage_reports:
        # Nothing to do: dataset already balanced enough
        return X_train, y_train

    # Debug information
    print(f"mean: {mean_n} | Undersample cap: {undersample_cap}")
    print("Before :", original_counts)
    for label, counts in stage_reports:
        print(label, counts)
    if not rus_strategy:
        print("After :", Counter(y_res))

    return X_res, y_res

In [None]:
class ResNetMLP(nn.Module):
    """Small MLP classifier with one residual (skip) connection.

    Layer widths go input_dim -> input_dim -> 2/3 input_dim -> 1/3 input_dim
    -> output_dim. The output of the first layer is also projected straight to
    the 1/3 width and added to the output of the third layer.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        two_thirds = int(input_dim * 2/3)
        one_third = int(input_dim * 1/3)

        self.fc1 = nn.Linear(input_dim, input_dim)
        self.fc2 = nn.Linear(input_dim, two_thirds)
        self.fc3 = nn.Linear(two_thirds, one_third)
        self.fc4 = nn.Linear(one_third, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

        # Projects the first layer's output to the width of the residual sum
        self.residual_transform = nn.Linear(input_dim, one_third)

    def forward(self, x):
        """Return the raw class scores (logits) for the batch ``x``."""
        # Entry layer followed by dropout
        hidden = self.dropout(self.relu(self.fc1(x)))

        # Residual block: two linear layers plus the projected shortcut
        branch = self.fc3(self.relu(self.fc2(hidden)))
        shortcut = self.residual_transform(hidden)
        merged = self.relu(branch + shortcut)

        # Output layer
        return self.fc4(merged)

In [None]:
def train_model(model, train_loader, num_epochs, model_name, criterion, optimizer, device):
    """Train ``model`` and save its weights under ``MODEL_OUTPUT_PATH``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable of ``(X_batch, y_batch)`` tensor pairs.
        num_epochs: Number of passes over ``train_loader``.
        model_name: File stem of the saved weights (``<model_name>.pth``).
        criterion: Loss called as ``criterion(outputs, y_batch)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        device: Device every batch is moved to before the forward pass.

    Raises:
        ValueError: If ``train_loader`` contains no batch.
    """
    import os

    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader yields no batches; nothing to train on")

    # Debug aid: makes autograd report the operation behind a failing backward
    # pass. It is a global switch, so it is enabled once rather than on every
    # batch. NOTE(review): it slows training down - consider removing it.
    torch.autograd.set_detect_anomaly(True)

    model.train()
    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=False)

        for X_batch, y_batch in progress_bar:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            # Forward pass
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            # Running accuracy: the predicted class is the highest score
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == y_batch).sum().item()
            total += y_batch.size(0)

            # Update tqdm progress bar
            progress_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

        # Print final metrics on the same line
        epoch_loss = total_loss / num_batches
        epoch_acc = 100 * correct / total
        print(f"\rEpoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%", end="", flush=True)

    # Save the weights only (state_dict); create the output folder if needed
    os.makedirs(MODEL_OUTPUT_PATH, exist_ok=True)
    torch.save(model.state_dict(), f"{MODEL_OUTPUT_PATH}/{model_name}.pth")
    print(f"\n{model_name} saved successfully!")


def evaluate_model(model, test_loader, target_names, device):
    """Print a classification report for ``model`` over ``test_loader``.

    The model is switched to evaluation mode and gradients are disabled. The
    predicted class of a sample is the index of its highest output score.
    ``target_names`` is forwarded to ``classification_report``.
    """
    model.eval()
    truths, guesses = [], []

    with torch.no_grad():
        for features, targets in tqdm(test_loader, desc="Evaluating"):
            features = features.to(device)
            targets = targets.to(device)
            predicted = torch.max(model(features), 1)[1]
            truths.extend(targets.cpu().numpy())
            guesses.extend(predicted.cpu().numpy())

    print("\nClassification Report:")
    print(classification_report(truths, guesses, target_names=target_names))

## Data

### Get data

In [None]:
# Load the dataset from the CSV path configured in CSV_DATASET
df_mailbox = pd.read_csv(CSV_DATASET)

In [None]:
# Drop the rows whose 'body' or 'label' is missing
df_mailbox.dropna(subset=['body', 'label'], inplace=True)

In [None]:
# Add an empty 'hash' column as the first column
df_mailbox.insert(0, 'hash', None)

In [None]:
# Clean every body, then store the SHA-256 of the cleaned text (UTF-8 encoded)
df_mailbox['body'] = df_mailbox['body'].apply(process_body)
df_mailbox['hash'] = df_mailbox['body'].apply(lambda body: calculate_hash(body.encode('utf-8')))

In [None]:
# Clean data: drop the rows that repeat an earlier (body, label) pair, then
# drop every row that still has a missing value in any column
df_mailbox.drop_duplicates(subset=['body', 'label'], inplace=True)
df_mailbox.dropna(inplace=True,axis=0)

In [None]:
# Add an empty 'vect' column at position 3; it receives each email's embedding
df_mailbox.insert(3, 'vect', None)

In [None]:
# Vectorize mails: encode the body of every row whose 'vect' is still empty
mask_not_vectorized = df_mailbox["vect"].isna()
new_bodies = df_mailbox.loc[mask_not_vectorized, "body"].tolist()

if new_bodies:
    indices_to_update = df_mailbox.index[mask_not_vectorized]

    # One encode() call per email; the vector is stored in the row's 'vect' cell
    for idx, body in tqdm(zip(indices_to_update, new_bodies), total=len(new_bodies), desc="Encoding emails"):
        vector = VECTORIZER.encode(body, convert_to_numpy=True)
        df_mailbox.at[idx, "vect"] = vector

In [None]:
# Drop the emails whose vector has a cosine similarity above 0.9 with an
# earlier email; debug=True prints the start of the bodies involved
df_mailbox = remove_similar_mails(df_mailbox, "vect", threshold=0.9, debug=True)

### Data overview

In [None]:
# Number of emails left after cleaning and near-duplicate removal
print(f"df_mailbox: {len(df_mailbox)}")

In [None]:
# Plot a bar chart of the categories in df_mailbox
plt.figure(figsize=(10, 6))
folder_counts = df_mailbox['label'].value_counts()
# One bar per entry of LABELS, in that order; labels absent from the data count 0
ordered_counts = [folder_counts.get(folder, 0) for folder in LABELS]
plt.bar(LABELS, ordered_counts)
plt.title('Distribution of Emails by Folder')
plt.xlabel('Folder')
plt.ylabel('Number of Emails')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


### Split data

In [None]:
# Split into train and test sets, stratified on the 'label' column
df_mailbox_train, df_mailbox_test = train_test_split(df_mailbox, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=df_mailbox["label"])

In [None]:
# Every label must be a key of y_encoder: Series.map() silently yields NaN for
# missing keys, which would corrupt the integer targets used further down.
unknown_labels = set(df_mailbox["label"]) - set(y_encoder)
if unknown_labels:
    raise ValueError(f"Labels missing from y_encoder: {sorted(map(str, unknown_labels))}")

# Stack the per-email vectors into 2-D feature matrices
X_train = np.vstack(df_mailbox_train["vect"].values)
X_test = np.vstack(df_mailbox_test["vect"].values)

# Encode the textual labels as integers
y_train = df_mailbox_train["label"].map(y_encoder).values
y_test = df_mailbox_test["label"].map(y_encoder).values

### Before oversampling

In [None]:
# Use the labels and colours configured in the parameters cell (as the plot
# after oversampling does) instead of a hard-coded copy of the example values
plot_email_statistics_and_pca(X_train, y_train, class_labels, class_colors, 'Safe vs Suspicious mails in df_THALES_train')

### After oversampling

In [None]:
# Rebalance the training set only; X_test and y_test are left untouched
X_train, y_train = oversample_data(X_train, y_train, smote_factor=SMOTE_FACTOR, ros_factor=ROS_FACTOR, random_state=RANDOM_STATE, cap_mult=CAP_MULT)

In [None]:
# Optional: append externally produced samples to the training set. They go
# through the same cleaning, hashing and vectorization as the main dataset.
if OVERSAMPLED_CSV is not None:
    df_oversampled = pd.read_csv(OVERSAMPLED_CSV)

    df_oversampled.dropna(subset=['body', 'label'], inplace=True)
    df_oversampled.insert(0, 'hash', None)

    df_oversampled['body'] = df_oversampled['body'].apply(lambda x: process_body(x))
    df_oversampled['hash'] = df_oversampled['body'].apply(lambda x: calculate_hash(x.encode('utf-8')))

    # Drop repeated (body, label) pairs, then rows with any missing value
    df_oversampled.drop_duplicates(subset=['body', 'label'], inplace=True)
    df_oversampled.dropna(inplace=True,axis=0)

    df_oversampled.insert(3, 'vect', None)

    df_oversampled['vect'] = df_oversampled['body'].apply(lambda x: VECTORIZER.encode(x, convert_to_numpy=True))

    X_oversampled = np.vstack(df_oversampled["vect"].values)
    y_oversampled = df_oversampled["label"].map(y_encoder).values

    # Append the extra samples after the resampled training data
    X_train = np.vstack((X_train, X_oversampled))
    y_train = np.hstack((y_train, y_oversampled))

In [None]:
# Class counts and PCA of the final training set, one colour per class.
# NOTE(review): the title is the one used for the safe/suspicious plot although
# this figure shows every class - confirm whether that is intended.
plot_email_statistics_and_pca(X_train, y_train, class_labels, class_colors, 'Safe vs Suspicious mails in df_THALES_train')

## Models

### Safe/Suspicious model

#### Data

In [None]:
# The safe/suspicious model uses the whole dataset (these are aliases, not copies)
X_train_safe_suspicious = X_train
X_test_safe_suspicious = X_test
y_train_safe_suspicious = y_train
y_test_safe_suspicious = y_test

In [None]:
# Binary targets: encoded labels 0 and 1 become 0, every other label becomes 1
# (assumes y_encoder gives ids 0-1 to the legitimate classes - see the example)
y_train_safe_suspicious = np.array([0 if label < 2 else 1 for label in y_train_safe_suspicious])
y_test_safe_suspicious = np.array([0 if label < 2 else 1 for label in y_test_safe_suspicious])

In [None]:
# Counts and PCA of the training set with the binary safe/suspicious targets
plot_email_statistics_and_pca(X_train_safe_suspicious, y_train_safe_suspicious, safe_suspicious_labels, safe_suspicious_colors, 'Safe vs Suspicious mails in df_THALES_train')

#### Train

In [None]:
# Convert the data to PyTorch tensors: float32 features (each row flattened)
# and int64 class targets
X_train_safe_suspicious_tensor = torch.tensor(np.array([np.array(row).flatten() for row in X_train_safe_suspicious]), dtype=torch.float32)
X_test_safe_suspicious_tensor = torch.tensor(np.array([np.array(row).flatten() for row in X_test_safe_suspicious]), dtype=torch.float32)
y_train_safe_suspicious_tensor = torch.tensor(y_train_safe_suspicious, dtype=torch.long)
y_test_safe_suspicious_tensor = torch.tensor(y_test_safe_suspicious, dtype=torch.long)

In [None]:
# Create DataLoader for batch processing
train_safe_suspicious_dataset = TensorDataset(X_train_safe_suspicious_tensor, y_train_safe_suspicious_tensor)
test_safe_suspicious_dataset = TensorDataset(X_test_safe_suspicious_tensor, y_test_safe_suspicious_tensor)

# Only the training batches are shuffled; the test order is kept
train_safe_suspicious_loader = DataLoader(train_safe_suspicious_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_safe_suspicious_loader = DataLoader(test_safe_suspicious_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Model dimensions: input width = vector size, output width = number of
# distinct classes found in the training targets
input_dim = X_train_safe_suspicious_tensor.shape[1]
output_dim = len(np.unique(y_train_safe_suspicious))

In [None]:
# Compute 'balanced' class weights from the training targets; they are handed
# to the loss function as a tensor on the training device
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_safe_suspicious),
    y=y_train_safe_suspicious
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

In [None]:
# Initialize the model, the class-weighted cross-entropy loss and the Adam optimizer
model_safe_suspicious = ResNetMLP(input_dim, output_dim).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.Adam(model_safe_suspicious.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the safe/suspicious model; its weights are saved as "<model_name>.pth"
model_name = "safe_suspicious_model"
train_model(model_safe_suspicious, train_safe_suspicious_loader, EPOCHS, model_name, criterion, optimizer, device)

#### Evaluate

In [None]:
# Evaluate model on the held-out test set (prints a classification report)
evaluate_model(model_safe_suspicious, test_safe_suspicious_loader, safe_suspicious_labels, device)

### Unwanted/Dangerous model

#### Data

In [None]:
# Keep only the samples whose encoded label is 2..7, i.e. drop labels 0 and 1
# (assumes the y_encoder layout of the example in the parameters section)
X_train_unwanted_dangerous = X_train[np.isin(y_train, [2, 3, 4, 5, 6, 7])]
y_train_unwanted_dangerous = y_train[np.isin(y_train, [2, 3, 4, 5, 6, 7])]
X_test_unwanted_dangerous = X_test[np.isin(y_test, [2, 3, 4, 5, 6, 7])]
y_test_unwanted_dangerous = y_test[np.isin(y_test, [2, 3, 4, 5, 6, 7])]

In [None]:
# Binary targets: encoded labels below 4 become 0, labels 4 and above become 1
y_train_unwanted_dangerous = np.array([0 if label < 4 else 1 for label in y_train_unwanted_dangerous])
y_test_unwanted_dangerous = np.array([0 if label < 4 else 1 for label in y_test_unwanted_dangerous])

In [None]:
# Use the labels and colours configured in the parameters cell instead of a
# hard-coded copy of the example values
plot_email_statistics_and_pca(X_train_unwanted_dangerous, y_train_unwanted_dangerous, unwanted_dangerous_labels, unwanted_dangerous_colors, 'Unwanted vs Dangerous mails in df_THALES_train')

#### Train

In [None]:
# Convert the data to PyTorch tensors: float32 features (each row flattened)
# and int64 class targets
X_train_unwanted_dangerous_tensor = torch.tensor(np.array([np.array(row).flatten() for row in X_train_unwanted_dangerous]), dtype=torch.float32)
X_test_unwanted_dangerous_tensor = torch.tensor(np.array([np.array(row).flatten() for row in X_test_unwanted_dangerous]), dtype=torch.float32)
y_train_unwanted_dangerous_tensor = torch.tensor(y_train_unwanted_dangerous, dtype=torch.long)
y_test_unwanted_dangerous_tensor = torch.tensor(y_test_unwanted_dangerous, dtype=torch.long)

In [None]:
# Create DataLoader for batch processing
train_unwanted_dangerous_dataset = TensorDataset(X_train_unwanted_dangerous_tensor, y_train_unwanted_dangerous_tensor)
test_unwanted_dangerous_dataset = TensorDataset(X_test_unwanted_dangerous_tensor, y_test_unwanted_dangerous_tensor)

# Only the training batches are shuffled; the test order is kept
train_unwanted_dangerous_loader = DataLoader(train_unwanted_dangerous_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_unwanted_dangerous_loader = DataLoader(test_unwanted_dangerous_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Model dimensions: input width = vector size, output width = number of
# distinct classes found in the training targets
input_dim = X_train_unwanted_dangerous_tensor.shape[1]
output_dim = len(np.unique(y_train_unwanted_dangerous))

In [None]:
# Compute 'balanced' class weights from the training targets; they are handed
# to the loss function as a tensor on the training device
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(y_train_unwanted_dangerous),
    y=y_train_unwanted_dangerous
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

In [None]:
# Initialize the model, the class-weighted cross-entropy loss and the Adam optimizer
model_unwanted_dangerous = ResNetMLP(input_dim, output_dim).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.Adam(model_unwanted_dangerous.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the unwanted/dangerous model; its weights are saved as "<model_name>.pth"
model_name = "unwanted_dangerous_model"
train_model(model_unwanted_dangerous, train_unwanted_dangerous_loader, EPOCHS, model_name, criterion, optimizer, device)

#### Evaluate

In [None]:
# Evaluate model on the held-out test subset (prints a classification report)
evaluate_model(model_unwanted_dangerous, test_unwanted_dangerous_loader, unwanted_dangerous_labels, device)

### Safe/Suspicious and Unwanted/Dangerous models

#### Data

In [None]:
# The combined evaluation uses the whole dataset (these are aliases, not copies)
X_train_global = X_train
X_test_global = X_test
y_train_global = y_train
y_test_global = y_test

In [None]:
# Three-way targets: encoded labels 0-1 become 0, labels 2-3 become 1 and
# labels 4 and above become 2
y_train_global = np.array([0 if label < 2 else 1 if label < 4 else 2 for label in y_train_global])
y_test_global = np.array([0 if label < 2 else 1 if label < 4 else 2 for label in y_test_global])

In [None]:
# Counts and PCA of the training set with the three-way targets
plot_email_statistics_and_pca(X_train_global, y_train_global, safe_unwanted_dangerous_labels, safe_unwanted_dangerous_colors, 'Safe vs Unwanted vs Dangerous mails in df_THALES_train')

#### Evaluate

In [None]:
# Test features as one float32 tensor (each row flattened); it is created on the CPU
X_test_global_tensor = torch.tensor(np.array([np.array(row).flatten() for row in X_test_global]), dtype=torch.float32)

In [None]:
# Ensure both models are in evaluation mode
model_safe_suspicious.eval()
model_unwanted_dangerous.eval()

In [None]:
THRESHOLD_SUSPICIOUS = 0.5

with torch.no_grad():
    # The model lives on `device`, so its input has to be moved there as well;
    # feeding the CPU tensor to a model on the GPU raises a device-mismatch error
    y_safe_suspicious_pred_logits = model_safe_suspicious(X_test_global_tensor.to(device))  # Get logits (raw outputs)

    # Convert logits to probabilities for the "suspicious" class (index 1)
    y_safe_suspicious_probs_suspicious = torch.softmax(y_safe_suspicious_pred_logits, dim=1)[:, 1]

    # Classify as suspicious if the probability exceeds the threshold.
    # 1 for suspicious, 0 for safe; brought back to the CPU for later cells.
    y_safe_suspicious_pred = (y_safe_suspicious_probs_suspicious > THRESHOLD_SUSPICIOUS).int().cpu()

In [None]:
# Minimum "dangerous" probability for a suspicious email to be reported as
# dangerous (class 2) rather than unwanted/spam (class 1).
THRESHOLD_DANGEROUS = 0.5

# Stage-1 verdicts as a CPU boolean mask: True where the email was flagged
# suspicious. A CPU mask can index a tensor on any device.
suspicious_mask = y_safe_suspicious_pred.bool().cpu()

# Every email starts as class 0 (safe); only suspicious rows are refined below.
final_predictions_tensor = torch.zeros(suspicious_mask.numel(), dtype=torch.long)

# Skip the second model entirely when nothing was flagged (avoids a forward
# pass on an empty batch).
if suspicious_mask.any():
    with torch.no_grad():
        # One batched forward pass over all suspicious emails instead of one
        # forward pass per email. The model is in eval mode, so each row is
        # scored independently of the others in the batch.
        suspicious_inputs = X_test_global_tensor[suspicious_mask].to(device)
        spam_dangerous_logits = model_unwanted_dangerous(suspicious_inputs)

        # Convert logits to probabilities; column 1 is the "dangerous" class.
        spam_dangerous_probs = torch.softmax(spam_dangerous_logits, dim=1)
        dangerous_probs = spam_dangerous_probs[:, 1].cpu()

    # above threshold -> True -> 1 + 1 = 2 (dangerous); otherwise 0 + 1 = 1 (spam)
    final_predictions_tensor[suspicious_mask] = (
        (dangerous_probs > THRESHOLD_DANGEROUS).long() + 1
    )

# Plain list of Python ints (0 = safe, 1 = spam, 2 = dangerous), one per test email.
final_predictions = final_predictions_tensor.tolist()

In [None]:
# Class indices in the same order as their display names. Passing them
# explicitly keeps the report and the confusion matrix at a fixed size even
# when a class is missing from both the ground truth and the predictions;
# otherwise classification_report raises on the target_names length mismatch
# and the heatmap tick labels no longer line up with the matrix.
class_indices = list(range(len(safe_unwanted_dangerous_labels)))

# Classification report
print(
    classification_report(
        y_test_global,
        final_predictions,
        labels=class_indices,
        target_names=safe_unwanted_dangerous_labels,
    )
)

# Generate confusion matrix (rows = true classes, columns = predicted classes)
conf_matrix = confusion_matrix(y_test_global, final_predictions, labels=class_indices)

# Plotting the confusion matrix
plt.figure(figsize=(8, 6))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt="d",
    cmap="Blues",
    xticklabels=safe_unwanted_dangerous_labels,
    yticklabels=safe_unwanted_dangerous_labels,
)
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.title("Confusion Matrix")
plt.show()

## Precise models

### Safe model

#### Data

In [None]:
# Keep only the samples labelled 0 or 1 (the two "safe" classes).
# Labels are left unchanged because they are already 0/1.
safe_train_mask = np.isin(y_train, [0, 1])
safe_test_mask = np.isin(y_test, [0, 1])

X_train_safe, y_train_safe = X_train[safe_train_mask], y_train[safe_train_mask]
X_test_safe, y_test_safe = X_test[safe_test_mask], y_test[safe_test_mask]

In [None]:
# Visual sanity check of the safe training subset
# (plot_email_statistics_and_pca is defined elsewhere in this notebook).
plot_email_statistics_and_pca(
    X_train_safe,
    y_train_safe,
    safe_labels,
    safe_colors,
    'Internal vs External mails in df_THALES_train',
)

#### Train

In [None]:
# Features: flatten each sample to 1-D, stack the rows into one matrix, and
# convert to float32 tensors.
X_train_safe_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_train_safe]),
    dtype=torch.float32,
)
X_test_safe_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_test_safe]),
    dtype=torch.float32,
)

# Targets: integer class indices (int64), as required by CrossEntropyLoss.
y_train_safe_tensor = torch.tensor(y_train_safe, dtype=torch.long)
y_test_safe_tensor = torch.tensor(y_test_safe, dtype=torch.long)

In [None]:
# Training split: (features, label) pairs served in reshuffled mini-batches.
train_safe_dataset = TensorDataset(X_train_safe_tensor, y_train_safe_tensor)
train_safe_loader = DataLoader(
    train_safe_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: same batch size, served in its original order.
test_safe_dataset = TensorDataset(X_test_safe_tensor, y_test_safe_tensor)
test_safe_loader = DataLoader(
    test_safe_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Network dimensions: number of input features per sample, and one output
# unit per distinct class present in the training labels.
input_dim = X_train_safe_tensor.size(1)
output_dim = np.unique(y_train_safe).size

In [None]:
# 'balanced' class weights computed from the safe training labels, so the
# rarer class contributes proportionally more to the loss.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train_safe),
    y=y_train_safe,
)
# Build the float32 weight tensor directly on the training device.
class_weights_tensor = torch.tensor(
    class_weights,
    dtype=torch.float32,
    device=device,
)

In [None]:
# Class-weighted cross-entropy loss for the safe classifier.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Fresh ResNetMLP on the training device, optimized with Adam.
model_safe = ResNetMLP(input_dim, output_dim)
model_safe = model_safe.to(device)
optimizer = optim.Adam(
    model_safe.parameters(),
    lr=LEARNING_RATE,
)

In [None]:
# Fit the safe classifier on its training loader.
# NOTE(review): model_name is presumably used by train_model to label the
# saved artefacts — confirm in train_model.
model_name = "safe_model"
train_model(
    model_safe,
    train_safe_loader,
    EPOCHS,
    model_name,
    criterion,
    optimizer,
    device,
)

#### Evaluate

In [None]:
# Report held-out performance of the safe classifier.
evaluate_model(
    model_safe,
    test_safe_loader,
    safe_labels,
    device,
)

### Unwanted model

#### Data

In [None]:
# Keep only the samples labelled 2 or 3 (the two "unwanted" classes).
unwanted_train_mask = np.isin(y_train, [2, 3])
unwanted_test_mask = np.isin(y_test, [2, 3])

X_train_unwanted, y_train_unwanted = X_train[unwanted_train_mask], y_train[unwanted_train_mask]
X_test_unwanted, y_test_unwanted = X_test[unwanted_test_mask], y_test[unwanted_test_mask]

In [None]:
# Re-index the labels for a two-class model: original label 2 -> 0,
# anything else (label 3 after the filtering above) -> 1.
y_train_unwanted = (np.asarray(y_train_unwanted) != 2).astype(int)
y_test_unwanted = (np.asarray(y_test_unwanted) != 2).astype(int)

In [None]:
# Visual sanity check of the unwanted training subset
# (plot_email_statistics_and_pca is defined elsewhere in this notebook).
plot_email_statistics_and_pca(
    X_train_unwanted,
    y_train_unwanted,
    unwanted_labels,
    unwanted_colors,
    'Spam vs Newsletter mails in df_THALES_train',
)

#### Train

In [None]:
# Features: flatten each sample to 1-D, stack the rows into one matrix, and
# convert to float32 tensors.
X_train_unwanted_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_train_unwanted]),
    dtype=torch.float32,
)
X_test_unwanted_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_test_unwanted]),
    dtype=torch.float32,
)

# Targets: integer class indices (int64), as required by CrossEntropyLoss.
y_train_unwanted_tensor = torch.tensor(y_train_unwanted, dtype=torch.long)
y_test_unwanted_tensor = torch.tensor(y_test_unwanted, dtype=torch.long)

In [None]:
# Training split: (features, label) pairs served in reshuffled mini-batches.
train_unwanted_dataset = TensorDataset(X_train_unwanted_tensor, y_train_unwanted_tensor)
train_unwanted_loader = DataLoader(
    train_unwanted_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: same batch size, served in its original order.
test_unwanted_dataset = TensorDataset(X_test_unwanted_tensor, y_test_unwanted_tensor)
test_unwanted_loader = DataLoader(
    test_unwanted_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Network dimensions: number of input features per sample, and one output
# unit per distinct class present in the training labels.
input_dim = X_train_unwanted_tensor.size(1)
output_dim = np.unique(y_train_unwanted).size

In [None]:
# 'balanced' class weights computed from the unwanted training labels, so the
# rarer class contributes proportionally more to the loss.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train_unwanted),
    y=y_train_unwanted,
)
# Build the float32 weight tensor directly on the training device.
class_weights_tensor = torch.tensor(
    class_weights,
    dtype=torch.float32,
    device=device,
)

In [None]:
# Class-weighted cross-entropy loss for the unwanted classifier.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Fresh ResNetMLP on the training device, optimized with Adam.
model_unwanted = ResNetMLP(input_dim, output_dim)
model_unwanted = model_unwanted.to(device)
optimizer = optim.Adam(
    model_unwanted.parameters(),
    lr=LEARNING_RATE,
)

In [None]:
# Fit the unwanted classifier on its training loader.
# NOTE(review): model_name is presumably used by train_model to label the
# saved artefacts — confirm in train_model.
model_name = "unwanted_model"
train_model(
    model_unwanted,
    train_unwanted_loader,
    EPOCHS,
    model_name,
    criterion,
    optimizer,
    device,
)

#### Evaluate

In [None]:
# Report held-out performance of the unwanted classifier.
evaluate_model(
    model_unwanted,
    test_unwanted_loader,
    unwanted_labels,
    device,
)

### Dangerous model

#### Data

In [None]:
# Keep only the samples labelled 4-7 (the four "dangerous" classes).
dangerous_train_mask = np.isin(y_train, [4, 5, 6, 7])
dangerous_test_mask = np.isin(y_test, [4, 5, 6, 7])

X_train_dangerous, y_train_dangerous = X_train[dangerous_train_mask], y_train[dangerous_train_mask]
X_test_dangerous, y_test_dangerous = X_test[dangerous_test_mask], y_test[dangerous_test_mask]

In [None]:
# Shift the labels so the dangerous classes start at 0 (4..7 -> 0..3).
y_train_dangerous = np.asarray(y_train_dangerous) - 4
y_test_dangerous = np.asarray(y_test_dangerous) - 4

In [None]:
# Visual sanity check of the dangerous training subset
# (plot_email_statistics_and_pca is defined elsewhere in this notebook).
plot_email_statistics_and_pca(
    X_train_dangerous,
    y_train_dangerous,
    dangerous_labels,
    dangerous_colors,
    'Classic phishing vs Whaling vs Clone vs Blackmail mails in df_THALES_train',
)

#### Train

In [None]:
# Features: flatten each sample to 1-D, stack the rows into one matrix, and
# convert to float32 tensors.
X_train_dangerous_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_train_dangerous]),
    dtype=torch.float32,
)
X_test_dangerous_tensor = torch.tensor(
    np.array([np.ravel(row) for row in X_test_dangerous]),
    dtype=torch.float32,
)

# Targets: integer class indices (int64), as required by CrossEntropyLoss.
y_train_dangerous_tensor = torch.tensor(y_train_dangerous, dtype=torch.long)
y_test_dangerous_tensor = torch.tensor(y_test_dangerous, dtype=torch.long)

In [None]:
# Training split: (features, label) pairs served in reshuffled mini-batches.
train_dangerous_dataset = TensorDataset(X_train_dangerous_tensor, y_train_dangerous_tensor)
train_dangerous_loader = DataLoader(
    train_dangerous_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)

# Test split: same batch size, served in its original order.
test_dangerous_dataset = TensorDataset(X_test_dangerous_tensor, y_test_dangerous_tensor)
test_dangerous_loader = DataLoader(
    test_dangerous_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
# Network dimensions: number of input features per sample, and one output
# unit per distinct class present in the training labels.
input_dim = X_train_dangerous_tensor.size(1)
output_dim = np.unique(y_train_dangerous).size

In [None]:
# 'balanced' class weights computed from the dangerous training labels, so
# rarer classes contribute proportionally more to the loss.
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(y_train_dangerous),
    y=y_train_dangerous,
)
# Build the float32 weight tensor directly on the training device.
class_weights_tensor = torch.tensor(
    class_weights,
    dtype=torch.float32,
    device=device,
)

In [None]:
# Class-weighted cross-entropy loss for the dangerous classifier.
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# Fresh ResNetMLP on the training device, optimized with Adam.
model_dangerous = ResNetMLP(input_dim, output_dim)
model_dangerous = model_dangerous.to(device)
optimizer = optim.Adam(
    model_dangerous.parameters(),
    lr=LEARNING_RATE,
)

In [None]:
# Fit the dangerous classifier on its training loader.
# NOTE(review): model_name is presumably used by train_model to label the
# saved artefacts — confirm in train_model.
model_name = "dangerous_model"
train_model(
    model_dangerous,
    train_dangerous_loader,
    EPOCHS,
    model_name,
    criterion,
    optimizer,
    device,
)

#### Evaluate

In [None]:
# Report held-out performance of the dangerous classifier.
evaluate_model(
    model_dangerous,
    test_dangerous_loader,
    dangerous_labels,
    device,
)