In [None]:
# from https://github.com/jackhhao/llm-warden/blob/main/src/train.py

In [1]:
from datasets import load_dataset, ClassLabel
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    TrainingArguments,
    Trainer
)
import numpy as np
import evaluate

# Label vocabulary: a name's position in `names` is its integer class id.
labels = ClassLabel(names=["benign", "jailbreak"])

# Class-id <-> class-name lookup tables, later handed to the model config.
id2label = dict(enumerate(labels.names))
label2id = {name: idx for idx, name in id2label.items()}

# Load the dataset and rename its columns to "text" / "label", the names
# the rest of this notebook reads.
dataset = (
    load_dataset("jackhhao/jailbreak-classification")
    .rename_column("prompt", "text")
    .rename_column("type", "label")
)

# Tokenizer matching the "bert-base-uncased" checkpoint fine-tuned below.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize_function(examples):
    """Tokenize a batch of examples and convert their labels to class ids.

    Args:
        examples: A batch with a "text" column (the prompts) and a "label"
            column (class names known to the module-level `labels`).

    Returns:
        The tokenizer output for the batch, with an added "label" entry
        holding the integer ids produced by `labels.str2int`.
    """
    encoded = tokenizer(examples["text"], truncation=True, padding="max_length")
    # Replace the string class names with their integer ids.
    encoded["label"] = labels.str2int(examples["label"])
    return encoded

# Tokenize every split, handing the mapping function whole batches at a time.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Shuffle both splits with a fixed seed. Despite the "small_" prefix, no
# subset is selected here: each variable holds the complete split.
small_train_dataset, small_eval_dataset = (
    tokenized_datasets[split_name].shuffle(seed=42)
    for split_name in ("train", "test")
)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Display the tokenized DatasetDict to check its splits, columns and row counts.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1044
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 262
    })
})

In [3]:

# Accuracy metric used to score the model at each evaluation pass.
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level accuracy metric.

    Args:
        eval_pred: A `(logits, labels)` pair as supplied by the Trainer.

    Returns:
        Whatever `metric.compute` returns for the predicted class ids
        against the reference labels.
    """
    logits, references = eval_pred
    # The predicted class is the index of the largest logit on the last axis.
    predicted_ids = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predicted_ids, references=references)

# Load the pretrained BERT encoder with a fresh 2-class classification head.
# The label maps are stored in the model config so saved checkpoints carry
# the "benign" / "jailbreak" names.
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,
    id2label=id2label,
    label2id=label2id
)

# Evaluate and checkpoint once per epoch, then reload the best checkpoint
# when training finishes.
# NOTE(review): the training log recorded under this cell shows 655 steps over
# 5 epochs (131 steps per epoch), which does not match num_train_epochs=2.
# The output appears to come from an earlier run; re-run the cell to refresh it.
training_args = TrainingArguments(
    output_dir="../training/",
    num_train_epochs=2,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

# Save the fine-tuned model together with its tokenizer. The Trainer was not
# given the tokenizer, so save_model() alone writes only the model files and
# "../model/" could not be loaded as a self-contained checkpoint.
trainer.save_model("../model/")
tokenizer.save_pretrained("../model/");  # trailing ';' hides the returned file list

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                 
 20%|██        | 131/655 [03:12<11:37,  1.33s/it]

{'eval_loss': 0.047433920204639435, 'eval_accuracy': 0.9847328244274809, 'eval_runtime': 12.4056, 'eval_samples_per_second': 21.119, 'eval_steps_per_second': 2.66, 'epoch': 1.0}


                                                 
 40%|████      | 262/655 [06:25<07:54,  1.21s/it]

{'eval_loss': 0.022801201790571213, 'eval_accuracy': 0.9961832061068703, 'eval_runtime': 11.88, 'eval_samples_per_second': 22.054, 'eval_steps_per_second': 2.778, 'epoch': 2.0}


                                                 
 60%|██████    | 393/655 [09:37<05:13,  1.20s/it]

{'eval_loss': 0.026989035308361053, 'eval_accuracy': 0.9923664122137404, 'eval_runtime': 11.7122, 'eval_samples_per_second': 22.37, 'eval_steps_per_second': 2.818, 'epoch': 3.0}


 76%|███████▋  | 500/655 [12:05<03:29,  1.35s/it]

{'loss': 0.0694, 'grad_norm': 0.003957709297537804, 'learning_rate': 1.1832061068702292e-05, 'epoch': 3.82}


                                                 
 80%|████████  | 524/655 [12:48<02:36,  1.19s/it]

{'eval_loss': 0.06803369522094727, 'eval_accuracy': 0.9885496183206107, 'eval_runtime': 11.7054, 'eval_samples_per_second': 22.383, 'eval_steps_per_second': 2.819, 'epoch': 4.0}


                                                 
100%|██████████| 655/655 [16:03<00:00,  1.20s/it]

{'eval_loss': 0.044569749385118484, 'eval_accuracy': 0.9923664122137404, 'eval_runtime': 11.8713, 'eval_samples_per_second': 22.07, 'eval_steps_per_second': 2.78, 'epoch': 5.0}


100%|██████████| 655/655 [16:05<00:00,  1.47s/it]


{'train_runtime': 965.6463, 'train_samples_per_second': 5.406, 'train_steps_per_second': 0.678, 'train_loss': 0.05352991091386052, 'epoch': 5.0}
