# Phishing Email Detector

## Dataset Download and Preparation

In [4]:
from pathlib import Path

# Directory where the Kaggle dataset is downloaded and unzipped
DATA_RAW_DIR = Path("data/raw")

# Make sure the directory exists (parents=True, exist_ok=True: no error if it already does)
DATA_RAW_DIR.mkdir(parents=True, exist_ok=True)

# Kaggle dataset slug (owner/dataset)
DATASET_NAME="subhajournal/phishingemails" # alternative: naserabdullahalam/phishing-email-dataset

# Check if the dataset is already downloaded. The download is skipped as soon as
# DATA_RAW_DIR contains any entry at all, so a partial/failed download is not retried.
if not any(DATA_RAW_DIR.iterdir()):
    print("Downloading dataset from Kaggle...")
    # IPython shell escape ({...} interpolates Python variables). Presumably needs the
    # Kaggle CLI installed with API credentials configured — TODO confirm in your environment.
    !kaggle datasets download -d {DATASET_NAME} -p {DATA_RAW_DIR} --unzip
else :
    print("Dataset already exists. Skipping download.")


Downloading dataset from Kaggle...
Dataset URL: https://www.kaggle.com/datasets/subhajournal/phishingemails
License(s): GNU Lesser General Public License 3.0
Downloading phishingemails.zip to data/raw
  0%|                                               | 0.00/18.0M [00:00<?, ?B/s]
100%|███████████████████████████████████████| 18.0M/18.0M [00:00<00:00, 875MB/s]


## Setup

In [None]:
import re
import os
import random
import math
from collections import Counter


import pandas as pd
import numpy as np
from bs4 import BeautifulSoup
from urlextract import URLExtract
from tqdm.auto import tqdm


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F


from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, roc_auc_score
from sklearn.preprocessing import StandardScaler

import nltk
# word_tokenize needs NLTK's Punkt models. Recent NLTK releases (3.8.2+) load
# them from the "punkt_tab" resource instead of "punkt", so fetch both; older
# releases only report "punkt_tab" as unknown (nltk.download does not raise).
nltk.download('punkt')
nltk.download('punkt_tab')
from nltk.tokenize import word_tokenize


# Reproducibility: seed the Python, NumPy and PyTorch RNGs
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


# Config (adjust these as you need)
# NOTE(review): the download cell unzips into data/raw/, but DATA_CSV points at
# the working directory — set it to the extracted CSV file.
DATA_CSV = "emails.csv" # path to your CSV
MAX_VOCAB = 40000  # cap on regular vocabulary tokens (<PAD>/<UNK> come on top)
MAX_LEN = 256  # token-id sequences are truncated to this length
EMBED_DIM = 128  # token embedding size
HIDDEN_DIM = 128  # LSTM hidden size per direction (the LSTM is bidirectional)
BATCH_SIZE = 64
EPOCHS = 6
LR = 1e-3  # Adam learning rate
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

[nltk_data] Downloading package punkt to /home/enrico/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


## Utilities

In [None]:
# Shared URLExtract instance; remove_urls() calls extractor.find_urls() on it.
extractor = URLExtract()


def clean_html(text):
    """Strip HTML markup and return only the text content.

    Anything that is not a ``str`` (e.g. None or a NaN cell) yields "".
    Text nodes are joined with a single space so that adjacent elements do
    not run together.
    """
    if isinstance(text, str):
        soup = BeautifulSoup(text, "html.parser")
        return soup.get_text(separator=" ")
    return ""


def remove_urls(text):
    """Replace every URL in *text* with the placeholder " URLTOKEN ".

    Uses the module-level ``extractor`` to find URLs.

    Returns:
        ``(text, n_urls, urls)``: the rewritten text, the number of URLs
        reported by the extractor, and that list in its original order.
    """
    urls = extractor.find_urls(text)
    # Replace longer URLs first. If one URL is a substring of another
    # (e.g. "http://a.com" and "http://a.com/login"), replacing the short one
    # first would leave a dangling "/login" fragment behind. dict.fromkeys
    # drops duplicates while keeping a deterministic order for equal lengths.
    for u in sorted(dict.fromkeys(urls), key=len, reverse=True):
        text = text.replace(u, " URLTOKEN ")
    return text, len(urls), urls


def simple_preprocess(text):
    """Normalise free text.

    Coerces the input to ``str``, collapses every whitespace run into a single
    space, trims both ends and lowercases the result.
    """
    collapsed = re.sub(r"\s+", " ", str(text))
    return collapsed.strip().lower()


def tokenize(text):
    """Split *text* into word tokens using NLTK's ``word_tokenize``."""
    tokens = word_tokenize(text)
    return tokens


# Small keyword feature helper: whole-word match on phishing-style trigger
# words. The pattern is case-sensitive, so it is meant for lowercased text
# (as produced by simple_preprocess).
_KEYWORD_REGEX = r"\b(password|verify|account|bank|login|confirm|click|urgent|reset)\b"
KEYWORD_PATTERN = re.compile(_KEYWORD_REGEX)

## Vocabulary Builder

In [None]:
def build_vocab(token_lists, max_vocab=MAX_VOCAB, min_freq=2):
    """Build token <-> id lookup tables from tokenised documents.

    Ids 0 and 1 are reserved for "<PAD>" and "<UNK>". The following ids go to
    the ``max_vocab`` most frequent tokens (most frequent first), keeping only
    those seen at least ``min_freq`` times.

    Args:
        token_lists: iterable of token sequences.
        max_vocab: cap on the number of regular (non-special) tokens; None
            keeps all of them.
        min_freq: minimum corpus count for a token to be kept.

    Returns:
        ``(stoi, itos)``: token -> id and id -> token dicts.
    """
    specials = ["<PAD>", "<UNK>"]
    counter = Counter()
    for toks in token_lists:
        counter.update(toks)
    # Never count the special tokens themselves. If one appeared in the corpus,
    # it would be listed twice: its id would move away from 0/1 and the
    # vocabulary would have a gap (max id >= len(stoi)), which breaks the
    # embedding size derived from len(stoi).
    for special in specials:
        counter.pop(special, None)
    vocab_tokens = specials + [
        tok for tok, c in counter.most_common(max_vocab) if c >= min_freq
    ]
    stoi = {tok: i for i, tok in enumerate(vocab_tokens)}
    itos = {i: tok for tok, i in stoi.items()}
    return stoi, itos

## Dataset Class

In [None]:
class EmailDataset(Dataset):
    """Email dataset yielding token-id sequences plus hand-crafted numeric features.

    ``df`` must have a ``label`` column. The ``subject`` and ``body`` text
    columns are optional: a missing column, None or NaN cell counts as empty
    text. All preprocessing runs once, in ``__init__``.

    Items are dicts with:
        "seq":   LongTensor of token ids, truncated to ``max_len`` (may be empty),
        "num":   FloatTensor of the numeric features (scaled if ``scaler`` is set),
        "label": 0-d LongTensor.
    """

    # Numeric feature order: n_urls, n_upper, n_exclaim, n_special, length,
    # has_login_words
    N_NUM_FEATS = 6

    def __init__(self, df, stoi, scaler=None, max_len=MAX_LEN):
        self.df = df.reset_index(drop=True)
        self.stoi = stoi
        self.max_len = max_len
        self.scaler = scaler
        # NOTE(review): unused by this class (remove_urls relies on the
        # module-level extractor); kept so existing attribute access still works.
        self.url_extractor = URLExtract()

        unk_id = self.stoi.get("<UNK>")
        feats = []
        self.seq_ids = []
        # A single pass per row: the cleaned text feeds both the numeric
        # features and the token ids.
        for _, row in self.df.iterrows():
            subj = self._as_text(row.get("subject", ""))
            body = self._as_text(row.get("body", ""))
            text = simple_preprocess(clean_html(subj + " " + body))
            text, n_urls, _ = remove_urls(text)

            # Character-level counts use the raw (pre-cleaning) subject + body.
            raw = subj + body
            feats.append(
                [
                    n_urls,
                    sum(1 for c in raw if c.isupper()),
                    raw.count("!"),
                    sum(1 for c in raw if not c.isalnum() and not c.isspace()),
                    len(text.split()),
                    int(bool(KEYWORD_PATTERN.search(text))),
                ]
            )

            ids = [self.stoi.get(tok, unk_id) for tok in tokenize(text)]
            self.seq_ids.append(ids[: self.max_len])

        # reshape keeps the array 2-D (n_rows x N_NUM_FEATS) even for an empty df
        self.num_feats = np.array(feats, dtype=np.float32).reshape(
            -1, self.N_NUM_FEATS
        )
        if scaler is not None:
            self.num_feats = scaler.transform(self.num_feats)

    @staticmethod
    def _as_text(value):
        """Coerce a DataFrame cell to text: None/NaN -> "", str unchanged, else str()."""
        if isinstance(value, str):
            return value
        if value is None or (isinstance(value, float) and math.isnan(value)):
            return ""
        return str(value)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        seq = self.seq_ids[idx]
        num = self.num_feats[idx].astype(np.float32)
        label = int(self.df.loc[idx, "label"])
        return {
            "seq": torch.tensor(seq, dtype=torch.long),
            "num": torch.tensor(num, dtype=torch.float),
            "label": torch.tensor(label, dtype=torch.long),
        }

    @staticmethod
    def collate_batch(batch):
        """Collate dataset items into a right-padded batch (pad id 0).

        Returns a dict with "seq" (B x T LongTensor), "lengths" (B,), "num"
        (B x F) and "label" (B,). Empty sequences are reported with length 1,
        i.e. a single <PAD> (id 0) step, because pack_padded_sequence rejects
        zero lengths.
        """
        seqs = [item["seq"] for item in batch]
        lengths = torch.tensor([max(len(s), 1) for s in seqs], dtype=torch.long)
        padded = torch.zeros(len(seqs), int(lengths.max()), dtype=torch.long)
        for i, s in enumerate(seqs):
            padded[i, : len(s)] = s
        nums = torch.stack([item["num"] for item in batch])
        labels = torch.stack([item["label"] for item in batch])
        return {"seq": padded, "lengths": lengths, "num": nums, "label": labels}


# Module-level alias so `DataLoader(..., collate_fn=collate_batch)` resolves.
collate_batch = EmailDataset.collate_batch

## Model

In [None]:
class Attention(nn.Module):
    """Attention pooling over a sequence.

    A single linear layer scores every timestep. The scores are
    softmax-normalised over the sequence and used as weights for a weighted
    sum of the inputs.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, 1)

    def forward(self, x, mask=None):
        """Pool ``x`` over its sequence dimension.

        Args:
            x: tensor of shape (batch, seq, dim).
            mask: optional (batch, seq) tensor. Positions where it equals 0
                receive (effectively) zero weight.

        Returns:
            ``(pooled, weights)`` with shapes (batch, dim) and (batch, seq, 1).
        """
        scores = self.proj(x).squeeze(-1)  # batch x seq
        if mask is not None:
            # Use the most negative finite value of the score dtype. A
            # hard-coded -1e9 cannot be represented in float16, so masked_fill
            # raises under half precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # batch x seq x 1
        pooled = (x * weights).sum(dim=1)  # batch x dim
        return pooled, weights


class PhishDetector(nn.Module):
    """BiLSTM + attention text encoder fused with numeric email features.

    Token ids -> embedding -> bidirectional LSTM -> attention pooling ->
    128-d text vector. That vector is concatenated with ``num_feats_dim``
    numeric features, passed through a 64-d hidden layer, and mapped to
    ``num_classes`` logits.
    """

    def __init__(self, vocab_size, emb_dim=EMBED_DIM, hidden_dim=HIDDEN_DIM, num_feats_dim=6, num_classes=2, dropout=0.3):
        super().__init__()
        # padding_idx=0: the <PAD> embedding starts at zero and gets no gradient
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(emb_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.att = Attention(hidden_dim * 2)  # forward + backward states
        self.fc_text = nn.Linear(hidden_dim * 2, 128)
        self.fc_comb = nn.Linear(128 + num_feats_dim, 64)
        self.out = nn.Linear(64, num_classes)
        self.dropout = nn.Dropout(dropout)

    def forward(self, seq, lengths, num_feats):
        """Return ``(logits, attention_weights)`` for a right-padded batch.

        Args:
            seq: (batch, seq_len) LongTensor of token ids, 0 = padding.
            lengths: (batch,) true sequence lengths; all must be >= 1
                (pack_padded_sequence rejects zero lengths).
            num_feats: (batch, num_feats_dim) numeric features.
        """
        emb = self.embedding(seq)
        packed = nn.utils.rnn.pack_padded_sequence(
            emb, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        packed_out, _ = self.lstm(packed)
        # total_length pads the LSTM output back to seq's width, so it always
        # lines up with the (seq != 0) mask, even if seq is wider than max(lengths).
        out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=seq.size(1)
        )
        mask = seq != 0
        attn_out, attn_weights = self.att(out, mask)
        x = F.relu(self.fc_text(attn_out))
        x = self.dropout(x)
        x = torch.cat([x, num_feats], dim=1)
        x = F.relu(self.fc_comb(x))
        x = self.dropout(x)
        logits = self.out(x)
        return logits, attn_weights

## Training and Evaluation

In [None]:
def compute_metrics(y_true, y_pred_probs, threshold=0.5):
    """Binary classification metrics from class-probability outputs.

    Args:
        y_true: 1-D array of true labels (0/1).
        y_pred_probs: (n_samples, n_classes) array of class probabilities;
            column 1 is used as the positive-class score.
        threshold: positive-class probability at or above which a sample is
            predicted as 1.

    Returns:
        dict with accuracy, precision, recall, f1 and roc_auc. roc_auc is NaN
        when roc_auc_score rejects the input (e.g. only one class in y_true).
    """
    pos_scores = y_pred_probs[:, 1]
    y_pred = (pos_scores >= threshold).astype(int)
    acc = accuracy_score(y_true, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    try:
        auc = roc_auc_score(y_true, pos_scores)
    except ValueError:
        # Raised when AUC is undefined (single class in y_true). Any other
        # error is a real bug and should surface.
        auc = float("nan")
    return {"accuracy": acc, "precision": p, "recall": r, "f1": f1, "roc_auc": auc}


def train_one_epoch(model, dataloader, opt, criterion):
    """Run one optimisation pass over *dataloader*.

    Puts the model in train mode and moves each batch's seq/lengths/num/label
    tensors to the module-level ``DEVICE`` before the forward pass.

    Returns:
        The mean of the per-batch loss values.
    """
    model.train()
    batch_losses = []
    progress = tqdm(dataloader, desc="train step")
    for batch in progress:
        seq, lengths, num, labels = (
            batch[key].to(DEVICE) for key in ("seq", "lengths", "num", "label")
        )
        opt.zero_grad()
        logits = model(seq, lengths, num)[0]
        loss = criterion(logits, labels)
        loss.backward()
        opt.step()
        batch_losses.append(loss.item())
    return np.mean(batch_losses)


def eval_model(model, dataloader):
    """Predict class probabilities for every batch in *dataloader*.

    Runs in eval mode without gradient tracking.

    Returns:
        ``(trues, probs)``: a 1-D array of true labels and an
        (n_samples, n_classes) array of softmax probabilities, in dataloader order.
    """
    model.eval()
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in dataloader:
            inputs = [batch[key].to(DEVICE) for key in ("seq", "lengths", "num")]
            labels = batch["label"].to(DEVICE)
            logits, _ = model(*inputs)
            prob_chunks.append(torch.softmax(logits, dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    probs = np.vstack(prob_chunks)
    return np.concatenate(label_chunks), probs

## Main Training Loop

In [None]:
if not os.path.exists(DATA_CSV):
    # raise instead of assert: asserts are stripped under `python -O`
    raise FileNotFoundError(f"CSV file not found: {DATA_CSV}")

# Load the raw emails.
df = pd.read_csv(DATA_CSV)
# NOTE(review): the pipeline expects "subject", "body" and an integer "label"
# column (0 = safe, 1 = phishing, matching predict_email's prob_safe/prob_phish).
# If the CSV uses other column names or string labels, map them here.
for col in ("subject", "body"):
    if col not in df.columns:
        df[col] = ""
    # fillna first: astype(str) alone would turn missing text into the string "nan"
    df[col] = df[col].fillna("").astype(str)


# Split
train_df, test_df = train_test_split(
    df, test_size=0.15, stratify=df["label"], random_state=SEED
)
train_df, val_df = train_test_split(
    train_df, test_size=0.1, stratify=train_df["label"], random_state=SEED
)


# Build vocab on train only: val/test tokens never influence the vocabulary
train_token_lists = []
for subj, body in zip(train_df["subject"], train_df["body"]):
    text = simple_preprocess(clean_html(subj + " " + body))
    text, _, _ = remove_urls(text)
    train_token_lists.append(tokenize(text))
stoi, itos = build_vocab(train_token_lists, max_vocab=MAX_VOCAB)
vocab_size = len(stoi)
print("Vocab size:", vocab_size)


# Fit the feature scaler on train only, then scale that same dataset in place.
# This is equivalent to EmailDataset(train_df, stoi, scaler=scaler) without
# preprocessing the whole training set a second time.
train_ds = EmailDataset(train_df, stoi, scaler=None)
scaler = StandardScaler().fit(train_ds.num_feats)
train_ds.num_feats = scaler.transform(train_ds.num_feats)
train_ds.scaler = scaler
val_ds = EmailDataset(val_df, stoi, scaler=scaler)
test_ds = EmailDataset(test_df, stoi, scaler=scaler)


# collate_batch is defined on EmailDataset, so reference it through the class
collate = EmailDataset.collate_batch
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate
)
test_loader = DataLoader(
    test_ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate
)


# Weighted loss for imbalance: "balanced" weights n_samples / (n_classes * count).
# The list is in sorted label order (assumes labels are 0..n_classes-1).
label_counts = train_df["label"].value_counts().sort_index()
total = label_counts.sum()
n_classes = len(label_counts)
weights = [total / (n_classes * c) for c in label_counts]
class_weights = torch.tensor(weights, dtype=torch.float).to(DEVICE)


# Model, loss, optimizer
model = PhishDetector(
    vocab_size=vocab_size,
    emb_dim=EMBED_DIM,
    hidden_dim=HIDDEN_DIM,
    num_feats_dim=train_ds.num_feats.shape[1],
)
model.to(DEVICE)
criterion = nn.CrossEntropyLoss(weight=class_weights)
opt = torch.optim.Adam(model.parameters(), lr=LR)


# Training loop: keep the checkpoint with the best validation F1
best_val_f1 = -1
for epoch in range(1, EPOCHS + 1):
    print(f"Epoch {epoch}/{EPOCHS}")
    train_loss = train_one_epoch(model, train_loader, opt, criterion)
    trues_val, probs_val = eval_model(model, val_loader)
    val_metrics = compute_metrics(trues_val, probs_val)
    print(f"Train loss: {train_loss:.4f} Val metrics: {val_metrics}")
    if val_metrics["f1"] > best_val_f1:
        best_val_f1 = val_metrics["f1"]
        # The checkpoint pickles the fitted StandardScaler as well as tensors,
        # so it must be read back with torch.load(..., weights_only=False).
        torch.save(
            {
                "model_state": model.state_dict(),
                "stoi": stoi,
                "itos": itos,
                "scaler": scaler,
            },
            "best_phish_model.pth",
        )
        print("Saved best model.")

## Evaluation

In [None]:
# The checkpoint also pickles the fitted StandardScaler, not just tensors.
# PyTorch >= 2.6 defaults to weights_only=True and would refuse to unpickle it,
# so disable that explicitly. Only do this for checkpoints you produced
# yourself: full unpickling can execute arbitrary code.
ckpt = torch.load("best_phish_model.pth", map_location=DEVICE, weights_only=False)
model.load_state_dict(ckpt["model_state"])
trues_test, probs_test = eval_model(model, test_loader)
test_metrics = compute_metrics(trues_test, probs_test)
print("Final test metrics:", test_metrics)

## Inference

In [None]:
def load_artifacts(filepath="best_phish_model.pth"):
    """Rebuild the trained PhishDetector and its preprocessing artefacts.

    Layer sizes (vocabulary, embedding, LSTM hidden size, numeric-feature
    count, classes) are read back from the saved weights. The checkpoint
    therefore still loads if the config globals changed after training.

    Returns:
        ``(model, stoi, scaler)`` with the model on DEVICE in eval mode.
    """
    # weights_only=False: the checkpoint pickles a StandardScaler, which the
    # PyTorch >= 2.6 default (weights_only=True) refuses to load. Only use this
    # on trusted, self-produced files — full unpickling can execute code.
    ckpt = torch.load(filepath, map_location=DEVICE, weights_only=False)
    state = ckpt["model_state"]
    vocab_size, emb_dim = state["embedding.weight"].shape
    # LSTM weight_hh_l0 has shape (4 * hidden, hidden)
    hidden_dim = state["lstm.weight_hh_l0"].shape[1]
    # fc_comb input = fc_text output (128) + number of numeric features
    num_feats_dim = state["fc_comb.weight"].shape[1] - state["fc_text.weight"].shape[0]
    model = PhishDetector(
        vocab_size=vocab_size,
        emb_dim=emb_dim,
        hidden_dim=hidden_dim,
        num_feats_dim=num_feats_dim,
        num_classes=state["out.weight"].shape[0],
    )
    model.load_state_dict(state)
    model.to(DEVICE)
    model.eval()
    return model, ckpt["stoi"], ckpt["scaler"]


def preprocess_single(subject, body, stoi, scaler, max_len=MAX_LEN):
    """Turn one email into batch-of-one model inputs.

    Applies the same cleaning, tokenisation and numeric features as the
    training dataset, then scales the features with the fitted ``scaler``.

    Returns:
        ``(seq, lengths, num)`` tensors on DEVICE with batch size 1.
    """
    raw = (subject or "") + " " + (body or "")
    text = clean_html(raw)
    text = simple_preprocess(text)
    text, n_urls, _ = remove_urls(text)
    toks = tokenize(text)
    ids = [stoi.get(tok, stoi.get("<UNK>")) for tok in toks][:max_len]
    if not ids:
        # An email without tokens would give a zero-length sequence, which
        # pack_padded_sequence rejects. Feed a single <PAD> (id 0) step instead.
        ids = [0]

    # Numeric features (same order as training): n_urls, n_upper, n_exclaim,
    # n_special, length, has_login_words. Character counts use the raw text.
    n_upper = sum(1 for c in raw if c.isupper())
    n_exclaim = raw.count("!")
    n_special = sum(1 for c in raw if not c.isalnum() and not c.isspace())
    length = len(text.split())
    has_login_words = int(bool(KEYWORD_PATTERN.search(text)))
    feats = np.array(
        [[n_urls, n_upper, n_exclaim, n_special, length, has_login_words]],
        dtype=np.float32,
    )
    feats = scaler.transform(feats)

    seq = torch.tensor([ids], dtype=torch.long).to(DEVICE)
    lengths = torch.tensor([len(ids)], dtype=torch.long).to(DEVICE)
    num = torch.tensor(feats, dtype=torch.float).to(DEVICE)
    return seq, lengths, num


def predict_email(subject, body, model, stoi, scaler, threshold=0.5):
    """Classify a single email.

    The model is used as given: put it in eval mode beforehand
    (``load_artifacts`` already does).

    Returns:
        dict with "prob_safe" and "prob_phish" (softmax probabilities of
        classes 0 and 1), "label" (1 when prob_phish >= threshold, else 0) and
        "attn_weights" (the attention weights as a NumPy array).
    """
    seq, lengths, num = preprocess_single(subject, body, stoi, scaler)
    with torch.no_grad():
        logits, attn = model(seq, lengths, num)
        class_probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
    phish_score = class_probs[1]
    return {
        "prob_safe": float(class_probs[0]),
        "prob_phish": float(phish_score),
        "label": 1 if phish_score >= threshold else 0,
        "attn_weights": None if attn is None else attn.cpu().numpy(),
    }