In [1]:
import torch
from torch.utils.data import Dataset
from torch_geometric.loader import DataLoader
from torch.nn import CrossEntropyLoss

import sys
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
sys.path.append('../src/')

import numpy as np

from GNNs.GCNClassifier import GCNClassifier
from GNNs.GRNClassifier import GRNClassifier

from embeddings.word2vec import Word2VecEmbedding
from embeddings.sbert import SBERTEmbedding

from utils.graph_of_words import GraphOfWords
from utils.graph_to_data import GraphToData
from utils.dataset_wrapper import DatasetWrapper


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import json
from pathlib import Path

with open('../models/hateXplain/dataset.json', 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

In [3]:
from sklearn.preprocessing import LabelEncoder

encoder = LabelEncoder()
encoder.classes_ = np.load('../models/hateXplain/classes.npy', allow_pickle=True)


In [4]:
from collections import Counter

def get_majority_label(annotators):
    labels = [a['label'] for a in annotators]
    most_common = Counter(labels).most_common(1)
    return most_common[0][0] if most_common else None

texts, labels = [], []

for post_id, post_data in raw_data.items():
    majority_label = get_majority_label(post_data['annotators'])
    if majority_label in encoder.classes_:
        label_id = encoder.transform([majority_label])[0]
        text = ' '.join(post_data['post_tokens'])
        texts.append(text)
        labels.append(label_id)

print(f'Loaded {len(texts)} samples.')

# Print the first 5 samples
for i in range(5):
    print(f'Text: {texts[i]}')
    print(f'Label: {labels[i]}')
    print()

Loaded 20148 samples.
Text: i dont think im getting my baby them white 9 he has two white j and nikes not even touched
Label: 1

Text: we cannot continue calling ourselves feminists if the rights of all womxn arent addressed yes to a sexual offences public list but will a trans lesbian bisexual and queer womxn be able to enter their information on the reporting sheet gender forum
Label: 1

Text: nawt yall niggers ignoring me
Label: 1

Text: <user> i am bit confused coz chinese ppl can not access twitter thn how this ching chong using it i think he pakistani 🤔 🤔 🤔
Label: 0

Text: this bitch in whataburger eating a burger with no top bun and holding it palm up i hate white bitches
Label: 0



In [5]:
from sklearn.model_selection import train_test_split

train_texts, test_texts, train_labels, test_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42, stratify=labels
)

val_texts, test_texts, val_labels, test_labels = train_test_split(
    test_texts, test_labels, test_size=0.5, random_state=42, stratify=test_labels
)


In [6]:
w2v_embedder = Word2VecEmbedding('../models/google/GoogleNews-vectors-negative300.kv', device=device)
sbert_embedder = SBERTEmbedding(device=device)
gow = GraphOfWords(embedding_model=sbert_embedder, window_size=2)
text_to_graph = GraphToData(gow)

In [7]:
from torch_geometric.loader import DataLoader

train_dataset = DatasetWrapper(train_texts, train_labels, text_to_graph)
val_dataset = DatasetWrapper(val_texts, val_labels, text_to_graph)
test_dataset = DatasetWrapper(test_texts, test_labels, text_to_graph)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)


In [8]:
print("Train dataset size:", len(train_dataset))

for i, data in enumerate(train_dataset):
    if data is None:
        print(f"Entry {i} is None")


all_labels = [data.y.item() for data in train_dataset]
print("Unique labels in training set:", set(all_labels))

Train dataset size: 16118
Unique labels in training set: {0, 1, 2}


In [9]:
from sklearn.metrics import (
    f1_score,
    roc_auc_score,
    precision_recall_fscore_support,
    confusion_matrix,
    classification_report
)

def train(model, loader, optimizer, loss_fn, device):
    model.train()
    total_loss = 0
    for data in loader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = loss_fn(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, device, label_names=None):
    model.eval()
    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)  # logits
            probs = torch.softmax(out, dim=1)
            preds = out.argmax(dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_targets.extend(data.y.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Convert to numpy arrays
    y_true = np.array(all_targets)
    y_pred = np.array(all_preds)
    y_probs = np.array(all_probs)

    # Global metrics
    f1 = f1_score(y_true, y_pred, average='macro')
    try:
        roc_auc = roc_auc_score(y_true, y_probs, multi_class='ovr', average='macro')
    except ValueError:
        roc_auc = float('nan')

    # Per-class metrics
    precision, recall, f1s, support = precision_recall_fscore_support(
        y_true, y_pred, average=None, zero_division=0
    )

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)

    print("\n--- Evaluation Metrics ---")
    print(f"Macro F1 Score: {f1:.4f}")
    print(f"Macro ROC AUC: {roc_auc:.4f}")
    print("\nClassification Report:")
    print(classification_report(y_true, y_pred, target_names=label_names or [str(i) for i in range(len(precision))]))

    print("Confusion Matrix:")
    print(cm)

    return f1, roc_auc

model = GCNClassifier(in_channels=384, hidden_channels=128, num_classes=3).to(device) # Adjust input size when changing embedder!!!!
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = CrossEntropyLoss()

# Map class indices to names
label_names = ["normal", "offensive", "hatespeech"]

for epoch in range(1, 11):
    train_loss = train(model, train_loader, optimizer, loss_fn, device)
    f1, roc_auc = evaluate(model, test_loader, device, label_names=label_names)
    print(f"Epoch {epoch:02d} | Train Loss: {train_loss:.4f} | F1: {f1:.4f} | ROC-AUC: {roc_auc:.4f}")
    print("=====================================")




--- Evaluation Metrics ---
Macro F1 Score: 0.5315
Macro ROC AUC: 0.7517

Classification Report:
              precision    recall  f1-score   support

      normal       0.66      0.62      0.64       624
   offensive       0.55      0.84      0.67       815
  hatespeech       0.61      0.19      0.29       576

    accuracy                           0.59      2015
   macro avg       0.61      0.55      0.53      2015
weighted avg       0.60      0.59      0.55      2015

Confusion Matrix:
[[387 207  30]
 [ 88 686  41]
 [115 351 110]]
Epoch 01 | Train Loss: 0.9843 | F1: 0.5315 | ROC-AUC: 0.7517

--- Evaluation Metrics ---
Macro F1 Score: 0.5759
Macro ROC AUC: 0.7664

Classification Report:
              precision    recall  f1-score   support

      normal       0.73      0.57      0.64       624
   offensive       0.59      0.78      0.67       815
  hatespeech       0.48      0.37      0.41       576

    accuracy                           0.60      2015
   macro avg       0.60     