In [1]:
import spacy
import random
import ast
import re
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import (
    f1_score,
    accuracy_score,
    precision_score,
    recall_score,
    confusion_matrix,
    roc_curve,
    auc,
    ConfusionMatrixDisplay
)
from spacy.training import Example

random.seed = 42

# ========== 1. Load data ==========
df_train = pd.read_csv("output_train.csv")
df_valid = pd.read_csv("output_valid.csv")
df_test  = pd.read_csv("output_test.csv")
# ========== 2. XML tagging utils ==========
def add_offsets(text, entities):
    if not isinstance(entities, list):
        return []
    used, results = [False] * len(text), []
    for ent in entities:
        word = ent.get("word")
        label = ent.get("entity", "").lower()
        if not word or not label:
            continue
        m = re.search(re.escape(word), text, re.IGNORECASE)
        if m and not any(used[m.start():m.end()]):
            results.append({"start": m.start(), "end": m.end(), "entity": label})
            for i in range(m.start(), m.end()):
                used[i] = True
    return results

def merge_adjacent_entities(ents):
    if not ents:
        return []
    ents = sorted(ents, key=lambda x: x["start"])
    merged = [ents[0]]
    for curr in ents[1:]:
        last = merged[-1]
        if curr["entity"] == last["entity"] and curr["start"] <= last["end"] + 1:
            last["end"] = curr["end"]
        else:
            merged.append(curr)
    return merged

def insert_xml_tags(text, entities):
    if isinstance(entities, str):
        try:
            entities = ast.literal_eval(entities)
        except:
            return text
    spans = add_offsets(text, entities)
    if not spans:
        return text
    merged = merge_adjacent_entities(spans)
    offset = 0
    for ent in sorted(merged, key=lambda x: x["start"]):
        o_tag, c_tag = f"<{ent['entity']}>", f"</{ent['entity']}>"
        s, e = ent["start"] + offset, ent["end"] + offset
        text = text[:s] + o_tag + text[s:e] + c_tag + text[e:]
        offset += len(o_tag) + len(c_tag)
    return text

# apply to all splits
for df in (df_train, df_valid, df_test):
    df["B_XML_statement"] = df.apply(
        lambda r: insert_xml_tags(r["statement"], r["B_raw_entities"]),
        axis=1
    )

# ========== 3. Prepare data ==========
def make_data(df, label_col):
    texts  = df["B_XML_statement"].tolist()
    labels = df[label_col].tolist()
    cats   = [{"cats": {"1": bool(l), "0": not bool(l)}} for l in labels]
    return list(zip(texts, cats)), texts, labels

train_data, _, _           = make_data(df_train, "label_binary")
valid_data, valid_texts, valid_labels = make_data(df_valid, "label_binary")
test_data,  test_texts,  test_labels  = make_data(df_test,  "label_binary")

# ========== 4. Build the model ==========
nlp = spacy.blank("en")
textcat = nlp.add_pipe(
    "textcat",
    last=True,
    config={
        "model": {
            "@architectures": "spacy.TextCatBOW.v3",
            "exclusive_classes": True,
            "ngram_size": 1,
            "no_output_layer": False
        }
    }
)
textcat.add_label("0")
textcat.add_label("1")

# ========== 5. Train w/ validation ==========
n_iter       = 10
train_losses = []
val_f1s      = []
optimizer    = nlp.begin_training()

for epoch in range(n_iter):
    random.shuffle(train_data)
    losses = {}
    # training
    for batch in spacy.util.minibatch(train_data, size=8):
        examples = []
        for text, ann in batch:
            doc = nlp.make_doc(text)
            examples.append(Example.from_dict(doc, ann))
        nlp.update(examples, drop=0.2, sgd=optimizer, losses=losses)
    train_losses.append(losses["textcat"])

    # validation eval
    preds      = [nlp(txt).cats for txt in valid_texts]
    pred_bin   = [int(p["1"] >= 0.5) for p in preds]
    f1         = f1_score(valid_labels, pred_bin)*100
    val_f1s.append(f1)

    print(
        f"Epoch {epoch+1}/{n_iter} — "
        f"Train Loss: {losses['textcat']:.4f} — "
        f"Val F1: {f1:.4f}%"
    )

# ========== 6. Test evaluation ==========
test_preds = [nlp(t).cats for t in test_texts]
test_bin   = [int(p["1"] >= 0.5) for p in test_preds]
scores     = {
    "accuracy":  accuracy_score(test_labels, test_bin)*100,
    "f1":        f1_score(test_labels, test_bin)*100,
    "precision": precision_score(test_labels, test_bin)*100,
    "recall":    recall_score(test_labels, test_bin)*100
}
print("📊 Test scores:", scores)

# ========== 7. Plots & save ==========
# Training + Validation curve
plt.figure()
plt.plot(range(1, n_iter+1), train_losses, marker="o", label="Train Loss")
plt.plot(range(1, n_iter+1), val_f1s,      marker="s", linestyle="--", label="Val F1")
plt.xlabel("Epoch")
plt.title("Training Loss & Validation F1")
plt.legend()
plt.tight_layout()
plt.savefig("spacy_xml_training_validation_curve.png")
plt.close()

# Confusion matrix
cm = confusion_matrix(test_labels, test_bin)
disp = ConfusionMatrixDisplay(cm, display_labels=["0","1"])
disp.plot(cmap="Blues", values_format="d")
plt.title("Confusion Matrix")
plt.tight_layout()
plt.savefig("confusion_matrix_spacy_xml.png")
plt.close()

# ROC curve
probs      = [p["1"] for p in test_preds]
fpr, tpr, _ = roc_curve(test_labels, probs)
roc_auc    = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.2f}")
plt.plot([0,1], [0,1], linestyle="--")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig("roc_curve_spacy_xml.png")
plt.close()

# Save results
pd.DataFrame([scores]).to_csv("spacy_with_xml_results.csv", index=False)
print("✅ DONE — all SpaCy XML tagging + metrics saved.")

Epoch 1/10 — Train Loss: 475.7278 — Val F1: 60.5856%
Epoch 2/10 — Train Loss: 367.9199 — Val F1: 62.7451%
Epoch 3/10 — Train Loss: 309.5594 — Val F1: 60.1175%
Epoch 4/10 — Train Loss: 269.9840 — Val F1: 61.7357%
Epoch 5/10 — Train Loss: 241.7007 — Val F1: 58.7856%
Epoch 6/10 — Train Loss: 219.1602 — Val F1: 61.0338%
Epoch 7/10 — Train Loss: 203.6863 — Val F1: 59.9696%
Epoch 8/10 — Train Loss: 189.2522 — Val F1: 58.6722%
Epoch 9/10 — Train Loss: 176.6376 — Val F1: 58.2474%
Epoch 10/10 — Train Loss: 166.5384 — Val F1: 58.5139%
📊 Test scores: {'accuracy': 66.59407665505228, 'f1': 58.96201177100053, 'precision': 61.49553571428571, 'recall': 56.62898252826311}
✅ DONE — all SpaCy XML tagging + metrics saved.
