In [None]:
import inspect

import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    Trainer,
    TrainingArguments,
)

# Read the pre-split train/validation/test CSVs (paths are relative to the
# notebook's working directory).
train_df, val_df, test_df = (
    pd.read_csv(csv_path)
    for csv_path in ("../data/train.csv", "../data/val.csv", "../data/test.csv")
)

# Wrap each DataFrame in a Hugging Face Dataset, keyed by the split names the
# rest of the pipeline indexes with ("train", "validation", "test").
dataset = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame)
        for split_name, frame in (
            ("train", train_df),
            ("validation", val_df),
            ("test", test_df),
        )
    }
)

# Tokenizer for the same "bert-base-uncased" checkpoint the classifier uses.
tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
)

# Tokenization function
def tokenize_function(example, max_length=128):
    """Tokenize a batch of comments for BERT.

    Used with ``Dataset.map(..., batched=True)``, so ``example`` is a batch and
    ``example["comment_text"]`` is a list of strings.

    Sequences are truncated to ``max_length`` tokens but deliberately NOT padded
    here: the ``DataCollatorWithPadding`` handed to the Trainer pads each batch
    to its own longest sequence. Padding every example to a fixed length at this
    stage would make that collator a no-op and waste compute on short comments.

    Args:
        example: Batch mapping containing a ``"comment_text"`` list.
        max_length: Maximum number of tokens kept per comment.

    Returns:
        The tokenizer's encoding for the batch (``input_ids``, ``attention_mask``, ...).
    """
    return tokenizer(example["comment_text"], truncation=True, max_length=max_length)

# Tokenize all splits
tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Drop the raw text (already encoded by tokenize_function) and the pandas index
# column. Dataset.from_pandas only materializes "__index_level_0__" when the
# DataFrame index is not a plain RangeIndex; pd.read_csv produces a RangeIndex,
# so the column is normally absent and removing it unconditionally raises
# ValueError. Remove only the columns each split actually has.
# NOTE(review): Trainer also needs a label column ("label"/"labels"); it is not
# visible here -- confirm the CSVs provide one.
_UNUSED_COLUMNS = ("comment_text", "__index_level_0__")
tokenized_dataset = DatasetDict(
    {
        split: split_ds.remove_columns(
            [col for col in _UNUSED_COLUMNS if col in split_ds.column_names]
        )
        for split, split_ds in tokenized_dataset.items()
    }
)
tokenized_dataset.set_format("torch")

# Collator that pads the examples of each batch to a common length as batches
# are assembled (it has nothing to do if tokenization already padded every
# example to a fixed length).
data_collator = DataCollatorWithPadding(tokenizer)

# Pretrained BERT with a two-label sequence-classification head (binary task).
model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path="bert-base-uncased",
    num_labels=2,
)

# Define compute_metrics function for Trainer
def compute_metrics(p: EvalPrediction):
    """Compute binary classification metrics for ``Trainer`` evaluation.

    Args:
        p: Evaluation output with logits in ``p.predictions`` and gold labels in
            ``p.label_ids``. ``p.predictions`` may also be a tuple whose first
            element is the logits (e.g. when the model returns extra outputs);
            only the logits are used.

    Returns:
        Dict with ``accuracy``, ``f1``, ``precision`` and ``recall``; the last
        three are for the positive class (label 1, sklearn's default
        ``pos_label``). Trainer prefixes these keys (e.g. ``eval_f1``) when logging.
    """
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.argmax(logits, axis=1)
    labels = p.label_ids
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds),
        # zero_division=0 returns the same 0.0 sklearn would fall back to when
        # nothing is predicted positive (precision) or no label is positive
        # (recall), without emitting an UndefinedMetricWarning every evaluation.
        "precision": precision_score(labels, preds, zero_division=0),
        "recall": recall_score(labels, preds, zero_division=0),
    }

# Training configuration.
# Newer transformers releases renamed `evaluation_strategy` to `eval_strategy`
# and then dropped the old keyword (passing it raises TypeError), while older
# releases only accept `evaluation_strategy`. Use whichever this install supports.
_EVAL_STRATEGY_KWARG = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="models/bert-toxic",
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    num_train_epochs=3,
    learning_rate=2e-5,
    # Reload the checkpoint with the best validation "f1" (the key returned by
    # compute_metrics) when training ends; this requires the evaluation and save
    # strategies to match, which is why both are "epoch".
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    # Keep at most one checkpoint (the best one may be retained in addition to
    # the latest).
    save_total_limit=1,
    seed=42,
    report_to="none",  # disable experiment-tracking integrations (wandb, tensorboard, ...)
    # Evaluate once per epoch, under the keyword name chosen above.
    **{_EVAL_STRATEGY_KWARG: "epoch"},
)

# Initialize Trainer.
# Newer transformers releases take the tokenizer as `processing_class` (the
# `tokenizer` keyword is deprecated/removed there); older releases only accept
# `tokenizer`. Pass it under whichever keyword this version supports.
_TOKENIZER_KWARG = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer.__init__).parameters
    else "tokenizer"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    **{_TOKENIZER_KWARG: tokenizer},
)

# Fine-tune. Because training_args sets load_best_model_at_end=True, the
# checkpoint with the best validation F1 is loaded back when training finishes.
trainer.train()

# Evaluate on the held-out test split. metric_key_prefix="test" labels the
# results test_* instead of the default eval_*, so they are not mistaken for
# validation scores.
test_results = trainer.evaluate(
    eval_dataset=tokenized_dataset["test"],
    metric_key_prefix="test",
)
print("Test results:", test_results)

# Set to True to export the final (best) model and its tokenizer.
SAVE_FINAL_MODEL = False
if SAVE_FINAL_MODEL:
    trainer.save_model("models/bert-toxic")
    tokenizer.save_pretrained("models/bert-toxic")