In [3]:
import os
import pickle

from datasets import load_dataset, concatenate_datasets, DatasetDict
from transformers import AutoTokenizer

from dataset_processing import datasets_config, sample_dataset

# Shared tokenizer, plus the only columns each evaluation split keeps.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
TOXICITY_THRESHOLD = 0.7
toxicchat_train = None
toxicchat_test = None
datasets = {}
required_columns = ["input_ids", "attention_mask", "label"]


def _tokenize_and_prune(split, preprocess_function):
    """Run ``preprocess_function`` over ``split`` in batches, then drop every column not in ``required_columns``."""
    tokenized = split.map(preprocess_function, fn_kwargs={"tokenizer": tokenizer}, batched=True)
    surplus_columns = [col for col in tokenized.column_names if col not in required_columns]
    return tokenized.remove_columns(surplus_columns)


for config in datasets_config:
    print(f"Processing {config['name']}...")
    dataset_name = config["name"]

    if dataset_name != "lmsys/toxic-chat":
        # Generic corpora: tokenize the configured split, draw a 3333-row sample, and keep
        # the "test" half of a seeded 50/50 split as the evaluation set.
        raw_dataset = load_dataset(dataset_name, split=config["split"])
        pruned = _tokenize_and_prune(raw_dataset, config["preprocess_function"])
        halves = sample_dataset(pruned, sample_size=3333).train_test_split(test_size=0.5, seed=1337)
        datasets[dataset_name] = halves["test"]
        continue

    # ToxicChat: load the configured subset, process both of its splits, and evaluate on a
    # 1667-row sample of its test split.
    raw_dataset = load_dataset(dataset_name, config["subset"])
    toxicchat_train = _tokenize_and_prune(raw_dataset["train"], config["preprocess_function"])
    toxicchat_test = _tokenize_and_prune(raw_dataset["test"], config["preprocess_function"])
    datasets[dataset_name] = sample_dataset(toxicchat_test, sample_size=1667)


datasets

Processing lmsys/toxic-chat...


Map:   0%|          | 0/5082 [00:00<?, ? examples/s]

Map:   0%|          | 0/5083 [00:00<?, ? examples/s]

Processing allenai/real-toxicity-prompts...


Map:   0%|          | 0/99442 [00:00<?, ? examples/s]

Processing tasksource/jigsaw_toxicity...


Map:   0%|          | 0/159571 [00:00<?, ? examples/s]

Processing google/civil_comments...


Map:   0%|          | 0/1804874 [00:00<?, ? examples/s]

{'lmsys/toxic-chat': Dataset({
     features: ['input_ids', 'attention_mask', 'label'],
     num_rows: 1667
 }),
 'allenai/real-toxicity-prompts': Dataset({
     features: ['input_ids', 'attention_mask', 'label'],
     num_rows: 1667
 }),
 'tasksource/jigsaw_toxicity': Dataset({
     features: ['input_ids', 'attention_mask', 'label'],
     num_rows: 1667
 }),
 'google/civil_comments': Dataset({
     features: ['input_ids', 'attention_mask', 'label'],
     num_rows: 1667
 })}

# Evaluation

In [None]:

from transformers import DataCollatorWithPadding, AutoModelForSequenceClassification
from metrics import evaluate_model

# Shared evaluation setup: `tokenizer` and `collator` are handed to evaluate_model, and
# `output_dir` is the directory evaluate_and_store_metrics writes its pickle files into.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
collator = DataCollatorWithPadding(tokenizer=tokenizer)
output_dir = "ablation"

def evaluate_and_store_metrics(model_name: str, hf_repo: str = "inxoy"):
    """Evaluate one checkpoint on every prepared dataset and pickle the metrics.

    Loads ``{hf_repo}/{model_name}``, calls ``evaluate_model`` once per entry of the
    notebook-level ``datasets`` dict, prints one Markdown table row per dataset, and
    writes the ``{dataset_name: metrics}`` mapping to ``{output_dir}/{model_name}.pkl``.

    Reads the notebook globals ``datasets``, ``collator``, ``tokenizer`` and ``output_dir``.

    Args:
        model_name: Repository name of the checkpoint; also used as the pickle file stem.
        hf_repo: Hub namespace that is prefixed to ``model_name``.
    """
    metrics_path = f"{output_dir}/{model_name}.pkl"
    # Create the target directory up front so a finished evaluation is never lost to a
    # FileNotFoundError when the results are written.
    os.makedirs(os.path.dirname(metrics_path) or ".", exist_ok=True)

    model = AutoModelForSequenceClassification.from_pretrained(f"{hf_repo}/{model_name}")

    metrics = {name: evaluate_model(model, dataset, collator, tokenizer) for name, dataset in datasets.items()}

    # Print summary
    for k, m in metrics.items():
        print(f"| {k} | {m['accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} |")

    with open(metrics_path, "wb") as f:
        pickle.dump(metrics, f)


## Mixed

In [None]:
# Evaluate the `inxoy/distilbert-mixed` checkpoint on each dataset in `datasets` and pickle the results.
evaluate_and_store_metrics("distilbert-mixed")

## Jigsaw

In [None]:
# Evaluate the `inxoy/distilbert-jsaw` checkpoint on each dataset in `datasets` and pickle the results.
evaluate_and_store_metrics("distilbert-jsaw")

## Civil Comments

In [None]:
# Evaluate the `inxoy/distilbert-cc` checkpoint on each dataset in `datasets` and pickle the results.
evaluate_and_store_metrics("distilbert-cc")


## ToxicChat

In [None]:
# Evaluate the `inxoy/distilbert-tc` checkpoint on each dataset in `datasets` and pickle the results.
evaluate_and_store_metrics("distilbert-tc")

## RealToxicityPrompts

In [None]:
# Evaluate the `inxoy/distilbert-rtp` checkpoint on each dataset in `datasets` and pickle the results.
evaluate_and_store_metrics("distilbert-rtp")

In [None]:
# Reload the metrics pickled by evaluate_and_store_metrics("distilbert-jsaw") and re-print
# them as Markdown table rows. The writer names files "<model_name>.pkl", so the stem is
# hyphenated ("distilbert-jsaw"), not "distilbert_jsaw".
with open("ablation/distilbert-jsaw.pkl", "rb") as f:
    metrics_dict = pickle.load(f)

# The file is only needed for loading, so the table is printed after it is closed.
for k, m in metrics_dict.items():
    print(f"| {k} | {m['accuracy']:.4f} | {m['precision']:.4f} | {m['recall']:.4f} | {m['f1']:.4f} |")

print()

# Combined Evaluation

In [None]:
# Score every checkpoint on the combined benchmark and pickle one metrics dict per model.
hf_repo = "inxoy"
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
collator = DataCollatorWithPadding(tokenizer=tokenizer)
# NOTE: this rebinds the notebook-level `output_dir` that evaluate_and_store_metrics reads.
output_dir = "ablation_mixed"
models = ["distilbert-jsaw", "distilbert-rtp", "distilbert-tc", "distilbert-mixed", "distilbert-cc"]

# NOTE(review): no split is requested here, so `ds` is whatever load_dataset returns for the
# whole repo. If evaluate_model expects a single split, select it here. TODO confirm.
ds = load_dataset("inxoy/toxicbench-mixed")

# Make sure the results directory exists before any model is evaluated.
os.makedirs(output_dir, exist_ok=True)

for model_name in models:
    model = AutoModelForSequenceClassification.from_pretrained(f"{hf_repo}/{model_name}")
    # Evaluate on the dataset loaded above; `dataset` and `k` were never defined in this cell.
    metrics = evaluate_model(model, ds, collator, tokenizer)
    print(f"| {model_name} | {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} | {metrics['f1']:.4f} |")

    with open(f"{output_dir}/{model_name}.pkl", "wb") as f:
        pickle.dump(metrics, f)

In [None]:
# Show the stored combined-benchmark metrics for distilbert-cc.
# Plain "rb" is enough: the file is only read, so no write access should be requested.
with open("ablation_mixed/distilbert-cc.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Re-run the combined-benchmark evaluation for distilbert-cc only and overwrite its pickle.
model = AutoModelForSequenceClassification.from_pretrained("inxoy/distilbert-cc")
# `ds` is the dataset loaded in the combined-evaluation cell; `dataset` and `k` are not
# defined at notebook scope, so the row is labelled with the model name directly.
metrics = evaluate_model(model, ds, collator, tokenizer)
print(f"| distilbert-cc | {metrics['accuracy']:.4f} | {metrics['precision']:.4f} | {metrics['recall']:.4f} | {metrics['f1']:.4f} |")

os.makedirs("ablation_mixed", exist_ok=True)
with open("ablation_mixed/distilbert-cc.pkl", "wb") as f:
    pickle.dump(metrics, f)

In [None]:
# Show the stored combined-benchmark metrics for distilbert-jsaw (read-only access).
with open("ablation_mixed/distilbert-jsaw.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the stored combined-benchmark metrics for distilbert-mixed (read-only access).
with open("ablation_mixed/distilbert-mixed.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the stored combined-benchmark metrics for distilbert-rtp (read-only access).
with open("ablation_mixed/distilbert-rtp.pkl", "rb") as f:
    print(pickle.load(f))

# Ablation

In [None]:
# Show the per-dataset metrics stored for distilbert-jsaw. evaluate_and_store_metrics writes
# "<model_name>.pkl", so the stem is hyphenated; the file is opened read-only.
with open("ablation/distilbert-jsaw.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the per-dataset metrics stored for distilbert-tc. evaluate_and_store_metrics writes
# "<model_name>.pkl", so the stem is hyphenated; the file is opened read-only.
with open("ablation/distilbert-tc.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the per-dataset metrics stored for distilbert-mixed. evaluate_and_store_metrics writes
# "<model_name>.pkl", so the stem is hyphenated; the file is opened read-only.
with open("ablation/distilbert-mixed.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the per-dataset metrics stored for distilbert-rtp (read-only access).
with open("ablation/distilbert-rtp.pkl", "rb") as f:
    print(pickle.load(f))

In [None]:
# Show the per-dataset metrics stored for distilbert-cc (read-only access).
with open("ablation/distilbert-cc.pkl", "rb") as f:
    print(pickle.load(f))