In [1]:
import os
import re
import zipfile
import math
import random
from collections import Counter

import requests
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, precision_recall_fscore_support,
    confusion_matrix, roc_auc_score
)

# -----------------------------
# 0) Reproducibility
# -----------------------------
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# -----------------------------
# 1) Download + Load Dataset
# -----------------------------
UCI_ZIP_URL = "https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip"
DATA_DIR = "./data_sms"
ZIP_PATH = os.path.join(DATA_DIR, "smsspamcollection.zip")
RAW_PATH = os.path.join(DATA_DIR, "SMSSpamCollection")

def download_and_extract():
    os.makedirs(DATA_DIR, exist_ok=True)

    if not os.path.exists(RAW_PATH):
        if not os.path.exists(ZIP_PATH):
            print(f"Downloading: {UCI_ZIP_URL}")
            r = requests.get(UCI_ZIP_URL, timeout=60)
            r.raise_for_status()
            with open(ZIP_PATH, "wb") as f:
                f.write(r.content)

        print("Extracting...")
        with zipfile.ZipFile(ZIP_PATH, "r") as z:
            z.extractall(DATA_DIR)

    assert os.path.exists(RAW_PATH), "Dataset file not found after extraction."
    print("Dataset ready:", RAW_PATH)

def load_sms():
    texts, labels = [], []
    with open(RAW_PATH, "r", encoding="utf-8") as f:
        for line in f:
            # format: label \t text
            parts = line.rstrip("\n").split("\t", maxsplit=1)
            if len(parts) != 2:
                continue
            lab, txt = parts
            labels.append(1 if lab == "spam" else 0)
            texts.append(txt)
    return texts, labels

# -----------------------------
# 2) Tokenize + Vocab
# -----------------------------
TOKEN_RE = re.compile(r"[A-Za-z0-9']+")

def simple_tokenize(text: str):
    # lowercase + keep simple tokens
    return TOKEN_RE.findall(text.lower())

def build_vocab(texts, min_freq=2, max_size=20000, special_tokens=("<pad>", "<unk>")):
    counter = Counter()
    for t in texts:
        counter.update(simple_tokenize(t))

    vocab = {tok: i for i, tok in enumerate(special_tokens)}
    for tok, freq in counter.most_common():
        if freq < min_freq:
            break
        if tok in vocab:
            continue
        vocab[tok] = len(vocab)
        if len(vocab) >= max_size:
            break
    return vocab

def numericalize(text, vocab):
    unk_id = vocab["<unk>"]
    return [vocab.get(tok, unk_id) for tok in simple_tokenize(text)]

# -----------------------------
# 3) Dataset / Collate
# -----------------------------
class SmsDataset(Dataset):
    def __init__(self, texts, labels, vocab, max_len=128):
        self.labels = labels
        self.seqs = []
        for t in texts:
            ids = numericalize(t, vocab)[:max_len]
            self.seqs.append(ids)
        self.max_len = max_len
        self.pad_id = vocab["<pad>"]

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.seqs[idx], self.labels[idx]

def collate_batch(batch, pad_id):
    seqs, labels = zip(*batch)
    max_len = max(len(s) for s in seqs)
    x = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    attn_mask = torch.zeros((len(seqs), max_len), dtype=torch.bool)  # True for padding
    for i, s in enumerate(seqs):
        x[i, :len(s)] = torch.tensor(s, dtype=torch.long)
        attn_mask[i, len(s):] = True
    y = torch.tensor(labels, dtype=torch.long)
    return x, attn_mask, y

# -----------------------------
# 4) Transformer Model
# -----------------------------
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout=0.1, max_len=512):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        if d_model % 2 == 1:
            # if odd dim, last cos part truncated
            pe[:, 1::2] = torch.cos(position * div_term[:-1])
        else:
            pe[:, 1::2] = torch.cos(position * div_term)

        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        # x: (B, T, D)
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)

class TransformerTextClassifier(nn.Module):
    def __init__(
        self, vocab_size: int, d_model=128, nhead=4, num_layers=2,
        dim_feedforward=256, dropout=0.1, num_classes=2, pad_id=0
    ):
        super().__init__()
        self.pad_id = pad_id
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout, max_len=512)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.classifier = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )

    def forward(self, input_ids, key_padding_mask):
        # input_ids: (B, T)
        x = self.embedding(input_ids)             # (B, T, D)
        x = self.pos_enc(x)                       # (B, T, D)
        x = self.encoder(x, src_key_padding_mask=key_padding_mask)  # (B, T, D)

        # mean pooling over non-pad tokens
        non_pad = (~key_padding_mask).float()     # (B, T)
        lengths = non_pad.sum(dim=1).clamp(min=1) # (B,)
        pooled = (x * non_pad.unsqueeze(-1)).sum(dim=1) / lengths.unsqueeze(-1)  # (B, D)

        logits = self.classifier(pooled)          # (B, C)
        return logits

# -----------------------------
# 5) Train / Eval
# -----------------------------
@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    all_y, all_pred, all_prob = [], [], []
    for x, pad_mask, y in loader:
        x, pad_mask, y = x.to(DEVICE), pad_mask.to(DEVICE), y.to(DEVICE)
        logits = model(x, pad_mask)
        prob = torch.softmax(logits, dim=-1)[:, 1]  # P(spam)
        pred = torch.argmax(logits, dim=-1)

        all_y.extend(y.cpu().tolist())
        all_pred.extend(pred.cpu().tolist())
        all_prob.extend(prob.cpu().tolist())

    acc = accuracy_score(all_y, all_pred)
    p, r, f1, _ = precision_recall_fscore_support(all_y, all_pred, average="binary", zero_division=0)
    cm = confusion_matrix(all_y, all_pred)
    auc = roc_auc_score(all_y, all_prob)
    return acc, p, r, f1, auc, cm

def train_one_epoch(model, loader, optimizer, criterion):
    model.train()
    total_loss = 0.0
    for x, pad_mask, y in loader:
        x, pad_mask, y = x.to(DEVICE), pad_mask.to(DEVICE), y.to(DEVICE)
        optimizer.zero_grad()
        logits = model(x, pad_mask)
        loss = criterion(logits, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += loss.item() * x.size(0)
    return total_loss / len(loader.dataset)

def main():
    download_and_extract()
    texts, labels = load_sms()

    # split
    X_train, X_test, y_train, y_test = train_test_split(
        texts, labels, test_size=0.2, random_state=42, stratify=labels
    )

    vocab = build_vocab(X_train, min_freq=2, max_size=20000)
    pad_id = vocab["<pad>"]

    train_ds = SmsDataset(X_train, y_train, vocab, max_len=128)
    test_ds  = SmsDataset(X_test,  y_test,  vocab, max_len=128)

    train_loader = DataLoader(
        train_ds, batch_size=64, shuffle=True,
        collate_fn=lambda b: collate_batch(b, pad_id)
    )
    test_loader = DataLoader(
        test_ds, batch_size=128, shuffle=False,
        collate_fn=lambda b: collate_batch(b, pad_id)
    )

    model = TransformerTextClassifier(
        vocab_size=len(vocab),
        d_model=128, nhead=4, num_layers=2,
        dim_feedforward=256, dropout=0.1,
        num_classes=2, pad_id=pad_id
    ).to(DEVICE)

    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    best_f1 = 0.0
    for epoch in range(1, 11):
        loss = train_one_epoch(model, train_loader, optimizer, criterion)
        acc, p, r, f1, auc, cm = evaluate(model, test_loader)

        print(f"Epoch {epoch:02d} | loss={loss:.4f} | acc={acc:.4f} | P={p:.4f} R={r:.4f} F1={f1:.4f} | AUC={auc:.4f}")
        print("Confusion Matrix:\n", cm)

        if f1 > best_f1:
            best_f1 = f1
            torch.save({"model": model.state_dict(), "vocab": vocab}, "best_sms_transformer.pt")

    print("Best F1:", best_f1)
    print("Saved to best_sms_transformer.pt")

if __name__ == "__main__":
    main()


Downloading: https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
Extracting...
Dataset ready: ./data_sms\SMSSpamCollection


  output = torch._nested_tensor_from_mask(


Epoch 01 | loss=0.3653 | acc=0.9489 | P=0.9259 R=0.6711 F1=0.7782 | AUC=0.9583
Confusion Matrix:
 [[958   8]
 [ 49 100]]
Epoch 02 | loss=0.1395 | acc=0.9668 | P=0.9242 R=0.8188 F1=0.8683 | AUC=0.9794
Confusion Matrix:
 [[956  10]
 [ 27 122]]
Epoch 03 | loss=0.0958 | acc=0.9731 | P=0.8889 R=0.9128 F1=0.9007 | AUC=0.9853
Confusion Matrix:
 [[949  17]
 [ 13 136]]
Epoch 04 | loss=0.0653 | acc=0.9776 | P=0.9559 R=0.8725 F1=0.9123 | AUC=0.9864
Confusion Matrix:
 [[960   6]
 [ 19 130]]
Epoch 05 | loss=0.0560 | acc=0.9785 | P=0.9032 R=0.9396 F1=0.9211 | AUC=0.9894
Confusion Matrix:
 [[951  15]
 [  9 140]]
Epoch 06 | loss=0.0429 | acc=0.9650 | P=0.8198 R=0.9463 F1=0.8785 | AUC=0.9902
Confusion Matrix:
 [[935  31]
 [  8 141]]
Epoch 07 | loss=0.0342 | acc=0.9830 | P=0.9392 R=0.9329 F1=0.9360 | AUC=0.9896
Confusion Matrix:
 [[957   9]
 [ 10 139]]
Epoch 08 | loss=0.0270 | acc=0.9821 | P=0.9329 R=0.9329 F1=0.9329 | AUC=0.9882
Confusion Matrix:
 [[956  10]
 [ 10 139]]
Epoch 09 | loss=0.0248 | acc=0.9