In [1]:
from datasets import load_dataset

ds = load_dataset("google/civil_comments")


In [2]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

import torch.nn as nn

# configuration
label_cols = [
    "toxicity",
    "severe_toxicity",
    "obscene",
    "threat",
    "insult",
    "identity_attack",
    "sexual_explicit",
]
MAX_LEN = 128
BATCH_SIZE = 32
EPOCHS = 2

train_ds = ds["train"]
val_ds = ds["validation"]
test_ds = ds["test"] if "test" in ds else None
print(f"Full train size: {len(train_ds)}")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


def preprocess(batch):
    enc = tokenizer(
        batch["text"], truncation=True, padding="max_length", max_length=MAX_LEN
    )
    # binarize labels at 0.5
    labels = []
    for i in range(len(batch["text"])):
        labels.append([1 if float(batch[c][i]) >= 0.5 else 0 for c in label_cols])
    enc["labels"] = labels
    return enc


train_enc = train_ds.map(preprocess, batched=True, remove_columns=train_ds.column_names)
val_enc = val_ds.map(preprocess, batched=True, remove_columns=val_ds.column_names)

test_enc = None
if test_ds is not None:
    test_enc = test_ds.map(
        preprocess, batched=True, remove_columns=test_ds.column_names
    )

train_enc.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_enc.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
if test_enc is not None:
    test_enc.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])

train_loader = DataLoader(train_enc, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_enc, batch_size=BATCH_SIZE * 2, shuffle=False)

test_loader = (
    DataLoader(test_enc, batch_size=BATCH_SIZE * 2, shuffle=False)
    if test_enc is not None
    else None
)


class RNNClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask):
        x = self.embedding(input_ids)
        outputs, _ = self.lstm(x)  # [B, T, 2H]
        lengths = attention_mask.sum(dim=1)
        idx = (
            (lengths - 1)
            .clamp(min=0)
            .unsqueeze(1)
            .unsqueeze(2)
            .expand(-1, 1, outputs.size(2))
        )
        last_hidden = outputs.gather(1, idx).squeeze(1)
        logits = self.fc(self.dropout(last_hidden))
        return logits


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Full train size: 1804874


Map:   0%|          | 0/97320 [00:00<?, ? examples/s]

Using device: cuda


In [3]:
from tqdm import tqdm
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def safe_auc(y_true, y_score):
    # y_true, y_score shape: [N, L]
    au = []
    for i in range(y_true.shape[1]):
        yi = y_true[:, i]
        ysi = y_score[:, i]
        if len(np.unique(yi)) < 2:
            tau.append(np.nan)
        else:
            try:
                val = roc_auc_score(yi, ysi)
                tau.append(val)
            except Exception:
                # in rare numerical cases
                tau.append(np.nan)
    return np.array(tau)


model = RNNClassifier(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    hidden_dim=128,
    num_labels=len(label_cols),
    pad_idx=tokenizer.pad_token_id,
).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-3)

for epoch in range(1, EPOCHS + 1):
    model.train()
    total_loss = 0.0
    for batch in tqdm(train_loader, desc=f"Training epoch {epoch}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].float().to(device)

        logits = model(input_ids, attention_mask)
        loss = criterion(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()

    model.eval()
    val_loss = 0.0
    val_acc = 0.0
    steps = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].float().to(device)

            logits = model(input_ids, attention_mask)
            loss = criterion(logits, labels)
            probs = torch.sigmoid(logits)
            preds = (probs >= 0.5).float()

            val_loss += loss.item()
            val_acc += (preds == labels).float().mean().item()
            steps += 1

    print(
        f"epoch {epoch} | train_loss {total_loss/len(train_loader):.4f} | "
        f"val_loss {val_loss/steps:.4f} | val_acc {val_acc/steps:.4f}"
    )

# Collect full-val and full-test predictions for per-label metrics


def collect_preds(dloader):
    all_y = []
    all_scores = []
    with torch.no_grad():
        for batch in dloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].float()
            logits = model(input_ids, attention_mask)
            probs = torch.sigmoid(logits).cpu()
            all_y.append(labels)
            all_scores.append(probs)
    return torch.cat(all_y, dim=0).numpy(), torch.cat(all_scores, dim=0).numpy()


val_y, val_scores = collect_preds(val_loader)
val_pred = (val_scores >= 0.5).astype(int)

if test_loader is not None:
    test_y, test_scores = collect_preds(test_loader)
    test_pred = (test_scores >= 0.5).astype(int)
else:
    test_y = test_scores = test_pred = None

# Compute per-label metrics
rows = []
for i, label in enumerate(label_cols):
    vy = val_y[:, i]
    vs = val_scores[:, i]
    vp = val_pred[:, i]
    v_auc = np.nan
    if len(np.unique(vy)) > 1:
        v_auc = roc_auc_score(vy, vs)
    v_f1 = f1_score(vy, vp)
    v_acc = accuracy_score(vy, vp)

    row = {"label": label, "val_auc": v_auc, "val_f1": v_f1, "val_acc": v_acc}

    if test_y is not None:
        ty = test_y[:, i]
        ts = test_scores[:, i]
        tp = test_pred[:, i]
        t_auc = np.nan
        if len(np.unique(ty)) > 1:
            t_auc = roc_auc_score(ty, ts)
        t_f1 = f1_score(ty, tp)
        t_acc = accuracy_score(ty, tp)
        row.update({"test_auc": t_auc, "test_f1": t_f1, "test_acc": t_acc})
    else:
        row.update({"test_auc": np.nan, "test_f1": np.nan, "test_acc": np.nan})

    rows.append(row)

metrics_df = pd.DataFrame(
    rows,
    columns=[
        "label",
        "val_auc",
        "val_f1",
        "val_acc",
        "test_auc",
        "test_f1",
        "test_acc",
    ],
)

avg_vals = metrics_df.drop(columns=["label"]).mean(numeric_only=True)
avg_row = {**{"label": "AVG"}, **avg_vals.to_dict()}
metrics_df = pd.concat([metrics_df, pd.DataFrame([avg_row])], ignore_index=True)

metrics_df = metrics_df.round(3)
metrics_df

Training epoch 1: 100%|██████████| 56403/56403 [10:54<00:00, 86.20it/s]


epoch 1 | train_loss 0.0467 | val_loss 0.0403 | val_acc 0.9857


Training epoch 2: 100%|██████████| 56403/56403 [10:46<00:00, 87.19it/s]


epoch 2 | train_loss 0.0413 | val_loss 0.0397 | val_acc 0.9858


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Unnamed: 0,label,val_auc,val_f1,val_acc,test_auc,test_f1,test_acc
0,toxicity,0.95,0.631,0.951,0.95,0.639,0.951
1,severe_toxicity,,0.0,1.0,,0.0,1.0
2,obscene,0.973,0.364,0.996,0.978,0.421,0.996
3,threat,0.955,0.149,0.998,0.953,0.123,0.998
4,insult,0.961,0.657,0.965,0.962,0.667,0.966
5,identity_attack,0.973,0.097,0.993,0.971,0.093,0.993
6,sexual_explicit,0.967,0.241,0.998,0.967,0.237,0.998
7,AVG,0.963,0.305,0.986,0.964,0.311,0.986


In [4]:
# CPU inference benchmarking on test set (RNN multi-label model)
import time
import numpy as np
import pandas as pd
import torch


original_device = next(model.parameters()).device
model_cpu = model.to("cpu").eval()

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        input_ids = batch["input_ids"].to("cpu")
        attention_mask = batch["attention_mask"].to("cpu")
        _ = model_cpu(input_ids, attention_mask)
        break

n_samples = 0
t0 = time.perf_counter()
with torch.no_grad():
    for batch in test_loader:
        input_ids = batch["input_ids"].to("cpu")
        attention_mask = batch["attention_mask"].to("cpu")
        _ = model_cpu(input_ids, attention_mask)
        n_samples += input_ids.size(0)
t1 = time.perf_counter()

total_seconds = t1 - t0
throughput = n_samples / total_seconds if total_seconds > 0 else float("inf")
per_sample_ms_all = (total_seconds / n_samples) * 1000.0

per_label_timings = [
    {"label": label, "test_infer_seconds": total_seconds, "per_sample_ms": per_sample_ms_all}
    for label in label_cols
]

time_metrics_df = pd.DataFrame(per_label_timings)

# Append AVG row
avg_vals = time_metrics_df.drop(columns=["label"]).mean(numeric_only=True)
avg_row = {**{"label": "AVG"}, **avg_vals.to_dict()}
time_metrics_df = pd.concat([time_metrics_df, pd.DataFrame([avg_row])], ignore_index=True)

# Round for readability
time_metrics_df["test_infer_seconds"] = time_metrics_df["test_infer_seconds"].round(6)
time_metrics_df["per_sample_ms"] = time_metrics_df["per_sample_ms"].round(6)

print("Benchmark (CPU) on test set:")
print(f" - samples: {n_samples}")
print(f" - total_inference_seconds_all_labels: {total_seconds:.6f}")
print(f" - throughput_samples_per_sec_all_labels: {throughput:.2f}")
print(f" - avg_per_sample_latency_ms_all_labels: {per_sample_ms_all:.6f}")

_ = model_cpu.to(original_device)

time_metrics_df

Benchmark (CPU) on test set:
 - samples: 97320
 - total_inference_seconds_all_labels: 57.510928
 - throughput_samples_per_sec_all_labels: 1692.20
 - avg_per_sample_latency_ms_all_labels: 0.590947


Unnamed: 0,label,test_infer_seconds,per_sample_ms
0,toxicity,57.510928,0.590947
1,severe_toxicity,57.510928,0.590947
2,obscene,57.510928,0.590947
3,threat,57.510928,0.590947
4,insult,57.510928,0.590947
5,identity_attack,57.510928,0.590947
6,sexual_explicit,57.510928,0.590947
7,AVG,57.510928,0.590947
