In [None]:
from pathlib import Path
import os, pandas as pd
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer
from sklearn.feature_extraction.text import ENGLISH_STOP_WORDS

# --- Configuration -----------------------------------------------------------
# All paths are resolved relative to the kernel's current working directory.
PROJECT_ROOT = Path.cwd()
DATA_IN = PROJECT_ROOT.joinpath("data", "cleaned_data.csv")
DATA_OUT_DIR = PROJECT_ROOT.joinpath("data", "tokenized_bert_uncased_max128")

MAX_LEN = 128             # max_length handed to the tokenizer
REMOVE_STOPWORDS = False  # True -> strip English stop words before tokenizing

# Fail fast when the cleaned CSV is not where it is expected.
assert DATA_IN.exists(), f"Missing file: {DATA_IN}"


# Read only the three columns used downstream, then discard rows that have
# no comment text or no label value at all.
df = pd.read_csv(DATA_IN, usecols=["comment", "bias_sent", "category"])
df = df.dropna(subset=["comment", "bias_sent"])

# Coerce labels to numbers (unparseable values become NaN) and keep only the
# rows whose label is exactly 0 or 1.
labels_num = pd.to_numeric(df["bias_sent"], errors="coerce")
mask = labels_num.isin([0, 1])
# NaN never matches isin(), so ~mask already counts the unparseable labels.
dropped = (~mask).sum()

# New frame with the kept rows plus an integer `labels` column.
df = df.loc[mask, ["comment", "category"]].assign(
    labels=labels_num[mask].astype(int)
)


def maybe_remove_stopwords(text: str) -> str:
    """Return ``text`` as a string, minus English stop words when enabled.

    The module-level ``REMOVE_STOPWORDS`` flag controls the behaviour. When it
    is falsy the input is only converted with ``str()``. Otherwise the text is
    split on whitespace, tokens whose lowercase form is in
    ``ENGLISH_STOP_WORDS`` are dropped, and the remaining tokens are joined
    with single spaces.
    """
    raw = str(text)
    if REMOVE_STOPWORDS:
        return " ".join(
            word for word in raw.split() if word.lower() not in ENGLISH_STOP_WORDS
        )
    return raw

# Build the model input text; stop words are only stripped when
# REMOVE_STOPWORDS is enabled.
df["text"] = df["comment"].astype(str).map(maybe_remove_stopwords)

print(f"Rows kept: {len(df)} | Dropped (invalid labels): {int(dropped)}")

# BERT uncased tokenizer, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

# Convert the relevant columns to a Hugging Face Dataset without the pandas index.
ds = Dataset.from_pandas(
    df.loc[:, ["text", "labels", "category"]],
    preserve_index=False,
)

def tok(batch):
    """Tokenize the ``"text"`` entries of ``batch`` with the notebook tokenizer.

    Sequences are truncated and padded to ``MAX_LEN`` tokens. The return value
    is whatever the tokenizer call produces for the batch.
    """
    texts = batch["text"]
    return tokenizer(
        texts,
        max_length=MAX_LEN,
        truncation=True,
        padding="max_length",
    )

# Run the tokenizer over the dataset in batches.
ds = ds.map(tok, batched=True, desc="Tokenizing")

# Keep the standard columns in this order; token_type_ids is included only
# when the tokenizer actually produced it.
keep = [
    col
    for col in ("input_ids", "token_type_ids", "attention_mask", "labels", "category")
    if col != "token_type_ids" or col in ds.column_names
]
ds = ds.select_columns(keep)


# Persist the tokenized dataset and report where it went.
DATA_OUT_DIR.mkdir(parents=True, exist_ok=True)
ds.save_to_disk(str(DATA_OUT_DIR))
print(f"Saved tokenized dataset to: {DATA_OUT_DIR}")


In [None]:
# Reload the saved dataset from disk to confirm it is readable, then print
# its summary object (column names and number of rows).
loaded = load_from_disk(str(DATA_OUT_DIR))
print(loaded)

# Read row 0 by row index. The column-first form loaded[col][0] can
# materialize an entire column (depending on the `datasets` version) just to
# take one element, and raises IndexError on an empty dataset.
if len(loaded) > 0:
    first_row = loaded[0]
    row0 = {k: first_row[k] for k in ["labels", "category"]}
else:
    row0 = {}
print("Row 0 label/category:", row0)

# Longest stored token sequence. `default=0` keeps max() from raising
# ValueError when the dataset has no rows.
max_len = (
    max((len(ids) for ids in loaded["input_ids"]), default=0)
    if "input_ids" in loaded.column_names
    else 0
)
print("Max seq len:", max_len)
