In [None]:
# ============================
# Toxic Comment Classifier (From Scratch)
#  - GRU + Attention
#  - BiLSTM + Attention
#  - Mini-Transformer (no pretrain)
# Colab-ready; dependencies: numpy, pandas, torch, scikit-learn, tqdm
# ============================

# ===== 0) Imports & Setup =====
import os, re, html, math, random
import numpy as np
import pandas as pd
from collections import Counter
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, random_split

from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, roc_auc_score

# Reproducibility: seed Python's `random`, NumPy's global RNG, torch's CPU RNG
# and every CUDA device's RNG with the same value.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# Prefer a CUDA device when torch can see one; otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print("Device:", device)

# ===== 1) Config =====
CSV_PATH = "/content/train.csv"   # <-- change to your file path in Colab/Drive
TEXT_COL = "comment_text"
LABEL_COLS = ['toxic','severe_toxic','obscene','threat','insult','identity_hate']

# Choose one: "gru", "lstm", or "transformer"
MODEL_TYPE = "gru"                 # <-- set here
# Fail fast on a typo here, instead of only after the (slow) vocab-building and
# encoding passes over the whole CSV.
if MODEL_TYPE not in ("gru", "lstm", "transformer"):
    raise ValueError("MODEL_TYPE must be 'gru', 'lstm', or 'transformer'.")
MAX_LEN = 128                      # every comment is truncated / right-padded to this many tokens
MIN_FREQ = 2                       # min token frequency for vocab
BATCH_SIZE = 64                    # bump on Colab; reduce on small GPU
EMBED_DIM = 128                    # token embedding size for the gru/lstm models
HIDDEN_DIM = 128                   # recurrent hidden size per direction (gru/lstm are bidirectional)
D_MODEL = 256                      # for transformer (also its embedding size)
EPOCHS = 8                         # upper bound; early stopping may end training sooner
LR = 2e-3 if MODEL_TYPE != "transformer" else 1e-3
PATIENCE = 2                       # epochs without val-loss improvement before early stopping
THRESHOLD = 0.5                    # probability at/above which a label counts as predicted

# ===== 2) Load Data =====
df = pd.read_csv(CSV_PATH)
# Check the schema with a real exception rather than `assert` (asserts are
# stripped under `python -O`), and report every missing column in one message
# instead of stopping at the first one.
_required_cols = [TEXT_COL, *LABEL_COLS]
_missing_cols = [c for c in _required_cols if c not in df.columns]
if _missing_cols:
    raise ValueError(
        f"Missing column(s) in {CSV_PATH}: {_missing_cols}; found {list(df.columns)}"
    )

print(df[TEXT_COL].head())

# ===== 3) Basic Cleaning & Tokenization =====
# Compiled once; clean_text applies the first three as "replace with a space"
# in this order, then collapses whitespace.
_URL_RE = re.compile(r"http\S+|www\S+|https\S+")
_EMAIL_RE = re.compile(r"\S+@\S+")
_NON_ALNUM_RE = re.compile(r"[^a-z0-9\s']")  # keep simple
_SPACE_RE = re.compile(r"\s+")

def clean_text(text: str) -> str:
    """Normalise a raw comment for tokenization.

    Steps: coerce to ``str``, decode HTML entities, lowercase, blank out URLs
    and e-mail-like tokens, replace every character other than ASCII letters,
    digits, whitespace and apostrophes with a space, then collapse runs of
    whitespace and trim the ends.
    """
    result = html.unescape(str(text)).lower()
    for pattern in (_URL_RE, _EMAIL_RE, _NON_ALNUM_RE):
        result = pattern.sub(" ", result)
    return _SPACE_RE.sub(" ", result).strip()

def simple_tokenize(text: str):
    """Return the whitespace-separated tokens of ``clean_text(text)``.

    clean_text keeps apostrophes, so contractions such as "don't" remain a
    single token.
    """
    normalised = clean_text(text)
    return normalised.split()

# Build vocab
# NOTE(review): counts come from every row of `df`, i.e. before the
# train/val/test split further down, so words seen only in val/test rows
# (freq >= MIN_FREQ) get their own ids instead of <UNK>. Consider building the
# vocabulary from the training rows only.
special_tokens = ["<PAD>", "<UNK>"]
counter = Counter()
for raw_text in tqdm(df[TEXT_COL].astype(str), desc="Building vocab"):
    counter.update(simple_tokenize(raw_text))

# Id layout: the special tokens first (<PAD>=0, <UNK>=1), then every token seen
# at least MIN_FREQ times, in first-seen order.
itos = [
    *special_tokens,
    *(tok for tok, freq in counter.items() if freq >= MIN_FREQ and tok not in special_tokens),
]
stoi = {tok: idx for idx, tok in enumerate(itos)}
PAD_IDX, UNK_IDX = stoi["<PAD>"], stoi["<UNK>"]
vocab_size = len(itos)
print("Vocab size:", vocab_size)

def encode(text, max_len=MAX_LEN):
    """Convert ``text`` to a list of exactly ``max_len`` token ids.

    Tokens past ``max_len`` are dropped, out-of-vocabulary tokens become
    ``UNK_IDX``, and shorter sequences are right-padded with ``PAD_IDX``.
    """
    token_ids = [stoi.get(tok, UNK_IDX) for tok in simple_tokenize(text)[:max_len]]
    token_ids.extend([PAD_IDX] * (max_len - len(token_ids)))
    return token_ids

# Vectorize the whole dataset
# all_ids: one row of MAX_LEN token ids per comment.
# all_labels: one float32 row of LABEL_COLS targets per comment.
_encoded_rows = [encode(t) for t in tqdm(df[TEXT_COL].astype(str), desc="Encoding texts")]
all_ids = np.vstack(_encoded_rows)
del _encoded_rows
all_labels = df[LABEL_COLS].astype('float32').values

# ===== 4) Train/Val/Test Split =====
# One shuffle of the row indices (NumPy's global RNG, seeded above), then an
# 80 / 10 / remainder split into train / val / test.
N = len(df)
idxs = np.arange(N)
np.random.shuffle(idxs)

train_ratio, val_ratio = 0.8, 0.1
n_train = int(train_ratio * N)
n_val = int(val_ratio * N)
train_idx, val_idx, test_idx = np.split(idxs, [n_train, n_train + n_val])


def _as_tensors(rows):
    """Gather ``rows`` of all_ids / all_labels as (long ids, float32 labels) tensors."""
    return (
        torch.tensor(all_ids[rows], dtype=torch.long),
        torch.tensor(all_labels[rows], dtype=torch.float32),
    )


def _make_loader(features, targets, shuffle):
    """Wrap a tensor pair in a BATCH_SIZE DataLoader that keeps the last partial batch."""
    return DataLoader(TensorDataset(features, targets), batch_size=BATCH_SIZE,
                      shuffle=shuffle, drop_last=False)


X_train, y_train = _as_tensors(train_idx)
X_val, y_val = _as_tensors(val_idx)
X_test, y_test = _as_tensors(test_idx)

train_loader = _make_loader(X_train, y_train, shuffle=True)
val_loader = _make_loader(X_val, y_val, shuffle=False)
test_loader = _make_loader(X_test, y_test, shuffle=False)

print(f"Splits -> train: {len(X_train)}, val: {len(X_val)}, test: {len(X_test)}")

# ===== 5) Models =====
class AttnPool(nn.Module):
    """Additive attention pooling over a sequence of hidden states.

    A single linear layer scores each time step. The scores are softmax-normalised
    over the time axis, with masked positions excluded, and used as weights for
    a weighted sum of the inputs.

    Args:
        in_dim: size C of the last dimension of ``seq_out``.
    """
    def __init__(self, in_dim):
        super().__init__()
        self.attn = nn.Linear(in_dim, 1)

    def forward(self, seq_out, mask=None):
        """Pool a (B, T, C) sequence down to (B, C).

        Args:
            seq_out: (B, T, C) sequence outputs.
            mask: optional (B, T) tensor; positions where ``mask == 0`` get zero
                weight. A row whose positions are all masked gets uniform weights.

        Returns:
            ``(pooled, weights)``: the (B, C) pooled vectors and (B, T) weights.
        """
        scores = self.attn(seq_out).squeeze(-1)  # (B, T)
        if mask is not None:
            # Use the most negative finite value of the scores' own dtype rather
            # than a hard-coded -1e9: -1e9 is not representable in float16, so
            # masked_fill rejects it under half precision / autocast. In float32
            # both values give zero weight to masked positions.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=1)    # (B, T)
        pooled = torch.bmm(weights.unsqueeze(1), seq_out).squeeze(1)  # (B, C)
        return pooled, weights

class GRUAttnClassifier(nn.Module):
    """Bidirectional GRU encoder + attention pooling + linear multi-label head.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: token embedding size.
        hidden_dim: GRU hidden size per direction (outputs are 2 * hidden_dim wide).
        num_labels: number of logits produced per example.
        pad_idx: token id used for padding; its embedding row is the padding row
            and those positions are excluded from attention pooling.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels, pad_idx=0):
        super().__init__()
        # Plain int attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.pad_idx = pad_idx
        self.emb = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.gru = nn.GRU(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attn = AttnPool(hidden_dim*2)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim*2, num_labels)

    def forward(self, x):
        """Map a (B, T) batch of token ids to (B, num_labels) logits."""
        # Mask from the pad id this model was built with (previously the global
        # PAD_IDX, which silently ignored the constructor's pad_idx argument).
        mask = (x != self.pad_idx).long()
        e = self.emb(x)                            # (B,T,E)
        # NOTE: the sequence is not packed, so padding positions are still fed
        # through the GRU; only the attention pooling ignores them.
        y, _ = self.gru(e)                         # (B,T,2H)
        pooled, _ = self.attn(y, mask=mask)        # (B,2H)
        out = self.dropout(pooled)
        return self.fc(out)                        # logits (B, L)

class LSTMAttnClassifier(nn.Module):
    """Bidirectional LSTM encoder + attention pooling + linear multi-label head.

    Args:
        vocab_size: number of rows in the embedding table.
        embed_dim: token embedding size.
        hidden_dim: LSTM hidden size per direction (outputs are 2 * hidden_dim wide).
        num_labels: number of logits produced per example.
        pad_idx: token id used for padding; its embedding row is the padding row
            and those positions are excluded from attention pooling.
    """
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_labels, pad_idx=0):
        super().__init__()
        # Plain int attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.pad_idx = pad_idx
        self.emb = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.attn = AttnPool(hidden_dim*2)
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(hidden_dim*2, num_labels)

    def forward(self, x):
        """Map a (B, T) batch of token ids to (B, num_labels) logits."""
        # Mask from the pad id this model was built with, not the global PAD_IDX.
        mask = (x != self.pad_idx).long()
        e = self.emb(x)                            # (B,T,E)
        # NOTE: no packing, so padding positions still pass through the LSTM;
        # only the attention pooling ignores them.
        y, _ = self.lstm(e)                        # (B,T,2H)
        pooled, _ = self.attn(y, mask=mask)        # (B,2H)
        out = self.dropout(pooled)
        return self.fc(out)                        # logits (B, L)

# Positional encoding for Transformer
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table added to (B, T, C) embeddings.

    Even columns hold sines and odd columns hold cosines of
    ``position / 10000**(2i / d_model)``.

    Args:
        d_model: embedding size C. Odd sizes are supported (the last sine
            column then has no cosine partner).
        max_len: longest sequence length T the table covers.
    """
    def __init__(self, d_model, max_len=MAX_LEN):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer odd column than even columns, so
        # trim div_term to fit (a no-op for even d_model).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` (B, T, C) plus the first T rows of the position table."""
        seq_len = x.size(1)
        table_len = self.pe.size(1)
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {table_len}"
            )
        return x + self.pe[:, :seq_len, :]

class MiniTransformerClassifier(nn.Module):
    """Small Transformer encoder trained from scratch, with masked mean pooling.

    Token embeddings plus sinusoidal positions pass through ``num_layers``
    ``nn.TransformerEncoderLayer``s. The outputs are averaged over non-pad
    positions and mapped to ``num_labels`` logits.

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding / model width.
        num_heads: attention heads per encoder layer.
        num_layers: number of encoder layers.
        num_labels: number of logits produced per example.
        pad_idx: token id used for padding; masked out of attention and pooling.
        dim_ff: feed-forward width inside each encoder layer.
        dropout: dropout inside the encoder layers (the head uses a fixed 0.2).
    """
    def __init__(self, vocab_size, d_model, num_heads, num_layers, num_labels, pad_idx=0, dim_ff=512, dropout=0.1):
        super().__init__()
        # Plain int attribute (not a parameter/buffer), so state_dict keys are unchanged.
        self.pad_idx = pad_idx
        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_idx)
        self.pos = PositionalEncoding(d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, dim_feedforward=dim_ff, dropout=dropout, batch_first=True)
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(d_model, num_labels)

    def forward(self, x):
        """Map a (B, T) batch of token ids to (B, num_labels) logits."""
        key_padding_mask = (x == self.pad_idx)  # (B, T), True = ignore
        # A row that is entirely padding (e.g. a comment that cleans to zero
        # tokens) would have every attention key masked. That can make the
        # attention softmax return NaN and poison the loss. Un-mask the first
        # position of such rows; it then also serves as their pooled vector.
        all_pad = key_padding_mask.all(dim=1)
        if all_pad.any():
            key_padding_mask = key_padding_mask.clone()
            key_padding_mask[all_pad, 0] = False
        e = self.emb(x)                    # (B,T,C)
        e = self.pos(e)
        h = self.encoder(e, src_key_padding_mask=key_padding_mask)
        # Pooling: mean over non-masked positions
        mask = (~key_padding_mask).unsqueeze(-1).to(h.dtype)   # (B,T,1)
        summed = (h * mask).sum(dim=1)
        lengths = mask.sum(dim=1).clamp(min=1)
        pooled = summed / lengths
        out = self.dropout(pooled)
        return self.fc(out)                # logits (B, L)

# Build the chosen model
num_labels = len(LABEL_COLS)
# One zero-arg factory per MODEL_TYPE; only the selected one is ever called, so
# only that model's weights are initialised.
_MODEL_BUILDERS = {
    "gru": lambda: GRUAttnClassifier(vocab_size, EMBED_DIM, HIDDEN_DIM, num_labels, pad_idx=PAD_IDX),
    "lstm": lambda: LSTMAttnClassifier(vocab_size, EMBED_DIM, HIDDEN_DIM, num_labels, pad_idx=PAD_IDX),
    "transformer": lambda: MiniTransformerClassifier(
        vocab_size, D_MODEL, num_heads=4, num_layers=2,
        num_labels=num_labels, pad_idx=PAD_IDX, dim_ff=512,
    ),
}
if MODEL_TYPE not in _MODEL_BUILDERS:
    raise ValueError("MODEL_TYPE must be 'gru', 'lstm', or 'transformer'.")
model = _MODEL_BUILDERS[MODEL_TYPE]().to(device)

_n_params = sum(p.numel() for p in model.parameters())
print(type(model).__name__, "params:", _n_params / 1e6, "M")

# ===== 6) Train Utilities =====
# Multi-label objective: an independent sigmoid + binary cross-entropy per label,
# computed directly from the logits.
criterion = nn.BCEWithLogitsLoss()

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-2)

# Reduce LR on plateau (simple & robust): halve the LR (factor=0.5) when the
# monitored loss stops decreasing, with patience=1. `verbose=True` was dropped:
# it is deprecated in recent PyTorch releases, and the training loop already
# prints the current LR at the start of every epoch.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)

def run_epoch(dataloader, train=True):
    """Run one pass over ``dataloader`` and return the dataset-average loss.

    With ``train=True`` the model is in training mode and weights are updated
    (gradient L2 norm clipped to 1.0). Otherwise the model is in eval mode and
    autograd is disabled, so evaluation does not build a graph.

    Uses the module-level ``model``, ``criterion``, ``optimizer`` and ``device``.

    Returns:
        Sum over batches of (batch-mean loss * batch size), divided by the
        number of examples in ``dataloader.dataset``.
    """
    model.train(train)
    total_loss = 0.0
    # Previously the eval pass still recorded autograd history for every batch.
    # That wastes memory and time, since backward() is never called on it.
    with torch.set_grad_enabled(train):
        for X, y in tqdm(dataloader, leave=False):
            X = X.to(device); y = y.to(device)
            logits = model(X)
            loss = criterion(logits, y)
            if train:
                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            total_loss += loss.item() * X.size(0)
    return total_loss / len(dataloader.dataset)

def evaluate_thresholded(dataloader, threshold=THRESHOLD):
    """Score ``dataloader`` with the global ``model`` and compute metrics at ``threshold``.

    Probabilities are ``sigmoid(logits)``; a label is predicted positive when
    its probability is ``>= threshold``.

    Returns:
        ``(per_label, macro_f1, micro_f1)``. ``per_label`` maps each name in
        LABEL_COLS to a dict with accuracy / precision / recall / f1 / roc_auc.
        roc_auc is NaN when sklearn cannot compute it for that column.
    """
    model.eval()
    prob_batches, label_batches = [], []
    with torch.no_grad():
        for X, y in dataloader:
            batch_probs = torch.sigmoid(model(X.to(device)))
            prob_batches.append(batch_probs.cpu().numpy())
            label_batches.append(y.numpy())
    probs = np.vstack(prob_batches)
    y_true = np.vstack(label_batches)
    y_pred = (probs >= threshold).astype(int)

    def _column_metrics(col):
        truth, pred, score = y_true[:, col], y_pred[:, col], probs[:, col]
        try:
            auc = roc_auc_score(truth, score)
        except ValueError:
            # e.g. only one class present in this column -> AUC undefined
            auc = float('nan')
        return {
            "accuracy": float(accuracy_score(truth, pred)),
            "precision": float(precision_score(truth, pred, zero_division=0)),
            "recall": float(recall_score(truth, pred, zero_division=0)),
            "f1": float(f1_score(truth, pred, zero_division=0)),
            "roc_auc": float(auc),
        }

    per_label = {label: _column_metrics(col) for col, label in enumerate(LABEL_COLS)}
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)
    return per_label, macro_f1, micro_f1

# ===== 7) Train Loop (with Early Stopping) =====
best_val = float('inf')
bad_epochs = 0
BEST_PATH = f"best_{MODEL_TYPE}_toxicity.pt"
# Tracks whether THIS run wrote BEST_PATH. A file with that name may be left
# over from an earlier run (possibly with a different vocabulary).
saved_this_run = False

for epoch in range(1, EPOCHS+1):
    print(f"\nEpoch {epoch}/{EPOCHS} (lr={optimizer.param_groups[0]['lr']:.2e})")
    train_loss = run_epoch(train_loader, train=True)
    val_loss   = run_epoch(val_loader, train=False)
    scheduler.step(val_loss)

    print(f" Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")

    # Require a small (1e-4) improvement so float noise doesn't reset patience.
    if val_loss < best_val - 1e-4:
        best_val = val_loss
        bad_epochs = 0
        torch.save(model.state_dict(), BEST_PATH)
        saved_this_run = True
        print("  ✅ Improved, model saved.")
    else:
        bad_epochs += 1
        print(f"  ⚠ No improvement ({bad_epochs}/{PATIENCE})")
        if bad_epochs >= PATIENCE:
            print("  ⛔ Early stopping.")
            break

# Load best
if not saved_this_run:
    # Only possible when val_loss was never below +inf (e.g. NaN every epoch).
    # BEST_PATH is then missing or stale, so refuse to load it silently.
    raise RuntimeError(
        f"No checkpoint was saved during this run (val loss never improved); "
        f"refusing to load {BEST_PATH!r}."
    )
# The checkpoint is a plain state_dict of tensors, so the safe weights-only
# loader suffices. map_location lets a GPU-saved file load on a CPU runtime.
model.load_state_dict(torch.load(BEST_PATH, map_location=device, weights_only=True))
model.eval()

# ===== 8) Final Evaluation =====
# Held-out test split, thresholded at THRESHOLD; one line per label, then the
# macro/micro F1 summary.
per_label, macro_f1, micro_f1 = evaluate_thresholded(test_loader, threshold=THRESHOLD)
print("\n=== Test Metrics ===")
for label_name, m in per_label.items():
    print(
        f"{label_name:14s}  acc={m['accuracy']:.3f}  prec={m['precision']:.3f}  "
        f"rec={m['recall']:.3f}  f1={m['f1']:.3f}  auc={m['roc_auc']:.3f}"
    )
print(f"\nMacro F1: {macro_f1:.3f} | Micro F1: {micro_f1:.3f}")

# ===== 9) Inference on your own text =====
def predict_comment(text, threshold=THRESHOLD):
    """Score one comment, print a per-label Yes/No verdict, and return the probabilities.

    Args:
        text: raw comment text (cleaned, tokenized and padded via ``encode``).
        threshold: probability at/above which a label is printed as "Yes".

    Returns:
        Dict mapping each label in LABEL_COLS to its sigmoid probability.
    """
    ids = torch.tensor([encode(text)], dtype=torch.long, device=device)
    # Force eval mode so dropout is off and the prediction is deterministic,
    # then restore whatever mode the model was in before this call.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            probs = torch.sigmoid(model(ids)).cpu().numpy()[0]
    finally:
        model.train(was_training)
    preds = (probs >= threshold).astype(int)
    print("\nComment:", text)
    for label, p, z in zip(LABEL_COLS, probs, preds):
        print(f"{label:14s}: {'Yes' if z==1 else 'No '} (prob: {p:.2f})")
    return {label: float(p) for label, p in zip(LABEL_COLS, probs)}

# Example:
# predict_comment("I hate you and you are awful!")
