In [1]:
import inspect

from datasets import load_dataset
import numpy as np
from sklearn.metrics import classification_report, accuracy_score
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer





In [2]:
# Load the IMDB dataset (used here as an example binary-classification corpus).
# Later cells use its "train" and "test" splits and the "text" / "label" columns.
DATASET_NAME = "imdb"
dataset = load_dataset(DATASET_NAME)


In [3]:
# Shuffle each split with a fixed seed, then keep only a small slice of it so
# that fine-tuning and evaluation finish in reasonable time.
SUBSET_SEED = 42
N_TRAIN_EXAMPLES = 2000
N_TEST_EXAMPLES = 1000

train_dataset = dataset["train"].shuffle(seed=SUBSET_SEED).select(range(N_TRAIN_EXAMPLES))
test_dataset = dataset["test"].shuffle(seed=SUBSET_SEED).select(range(N_TEST_EXAMPLES))


In [4]:
# Tokenize both subsets with the tokenizer that matches the pretrained checkpoint.
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


def tokenize_function(example):
    """Tokenize the "text" field of a batch of examples.

    Because the dataset is mapped with ``batched=True`` below, ``example["text"]``
    is a list of strings rather than a single string. Every sequence is
    truncated and padded to the same fixed length.
    """
    # NOTE(review): padding="max_length" with no explicit max_length pads to the
    # tokenizer's model maximum (presumably 512 for this checkpoint; confirm).
    # This is likely why training is slow. Dynamic per-batch padding would be
    # much cheaper but needs a padding-aware collator on the Trainer.
    texts = example["text"]
    return tokenizer(texts, truncation=True, padding="max_length")


tokenized_train, tokenized_test = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, test_dataset)
)


Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [5]:
# Restrict both tokenized splits to the model inputs plus the label column,
# returned as PyTorch tensors.
TORCH_COLUMNS = ["input_ids", "attention_mask", "label"]
for tokenized_split in (tokenized_train, tokenized_test):
    # Pass a fresh list per split so the two datasets never share one list object.
    tokenized_split.set_format("torch", columns=list(TORCH_COLUMNS))


In [6]:
# Load the pretrained checkpoint with a newly initialised classification head.
# The head size and the id <-> label mapping come from the dataset's label
# feature rather than a hard-coded `num_labels=2`, so they always match the
# data, and readable label names are stored in the model config.
# NOTE(review): assumes the "label" column is a ClassLabel feature exposing `.names`.
label_names = train_dataset.features["label"].names
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# Define training arguments.
# `evaluation_strategy` was renamed to `eval_strategy` in newer transformers
# releases, and the old keyword was later removed (passing it then raises
# TypeError). Use whichever keyword the installed TrainingArguments accepts so
# this cell runs on both older and newer versions.
eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    report_to="none",
    # Evaluate on the Trainer's eval_dataset at the end of every epoch.
    **{eval_strategy_kwarg: "epoch"},
)




In [8]:
# Define compute_metrics function
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair supplied by the Trainer. ``logits``
            holds one row of class scores per example. It may arrive as a tuple
            of arrays; in that case the first element is treated as the logits.

    Returns:
        dict with ``"accuracy"`` (float) and ``"report"`` (the
        ``classification_report`` dict, keyed by class label as a string,
        e.g. ``"0"``). The Trainer exposes these as ``eval_accuracy`` and
        ``eval_report``.
    """
    logits, labels = eval_pred
    # Guard against models that return extra outputs next to the logits.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        # zero_division=0: when the model predicts only one class (common early
        # in training), report 0.0 for the undefined precision instead of
        # raising an UndefinedMetricWarning on every evaluation.
        "report": classification_report(
            labels, predictions, output_dict=True, zero_division=0
        ),
    }


In [9]:
# Create Trainer: bundle the model, the hyperparameters, the metric callback and
# the tokenized splits. The tokenized test subset is what gets evaluated.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
)


In [10]:
# Fine-tune the model. Per the training arguments, the eval subset is scored at
# the end of every epoch. Keep the returned TrainOutput (global step, loss,
# runtime) and display it.
train_output = trainer.train()
train_output


Epoch,Training Loss,Validation Loss,Accuracy,Report
1,0.2229,0.347104,0.849,"{'0': {'precision': 0.941320293398533, 'recall': 0.751953125, 'f1-score': 0.8360477741585234, 'support': 512.0}, '1': {'precision': 0.7851099830795262, 'recall': 0.9508196721311475, 'f1-score': 0.8600556070435589, 'support': 488.0}, 'accuracy': 0.849, 'macro avg': {'precision': 0.8632151382390296, 'recall': 0.8513863985655737, 'f1-score': 0.8480516906010411, 'support': 1000.0}, 'weighted avg': {'precision': 0.8650896619628576, 'recall': 0.849, 'f1-score': 0.8477635966064208, 'support': 1000.0}}"
2,0.2388,0.302303,0.879,"{'0': {'precision': 0.8825831702544031, 'recall': 0.880859375, 'f1-score': 0.8817204301075269, 'support': 512.0}, '1': {'precision': 0.8752556237218814, 'recall': 0.8770491803278688, 'f1-score': 0.8761514841351075, 'support': 488.0}, 'accuracy': 0.879, 'macro avg': {'precision': 0.8789193969881423, 'recall': 0.8789542776639344, 'f1-score': 0.8789359571213171, 'support': 1000.0}, 'weighted avg': {'precision': 0.8790073275465324, 'recall': 0.879, 'f1-score': 0.8790027844729862, 'support': 1000.0}}"


TrainOutput(global_step=250, training_loss=0.3215472717285156, metrics={'train_runtime': 3824.3859, 'train_samples_per_second': 1.046, 'train_steps_per_second': 0.065, 'total_flos': 529869594624000.0, 'train_loss': 0.3215472717285156, 'epoch': 2.0})

In [11]:
# Evaluate the model and flatten the nested classification report into one
# summary dict: overall accuracy followed by precision / recall / F1 per class.
results = trainer.evaluate()
report = results["eval_report"]

# Per-class entries are every report key except the aggregate rows
# ("accuracy", "macro avg", "weighted avg", ...). Deriving them instead of
# hard-coding "0" / "1" keeps this cell correct for any number of classes.
class_labels = [
    label for label in report if label != "accuracy" and not label.endswith(" avg")
]

METRIC_DISPLAY_NAMES = (
    ("precision", "Precision"),
    ("recall", "Recall"),
    ("f1-score", "F1-score"),
)

results_summary = {"Accuracy": results["eval_accuracy"]}
for label in class_labels:
    for metric_key, display_name in METRIC_DISPLAY_NAMES:
        results_summary[f"{display_name} (class {label})"] = report[label][metric_key]

# Shown as the cell's last expression (rich display) rather than via print().
results_summary


{'Accuracy': 0.879, 'Precision (class 0)': 0.8825831702544031, 'Recall (class 0)': 0.880859375, 'F1-score (class 0)': 0.8817204301075269, 'Precision (class 1)': 0.8752556237218814, 'Recall (class 1)': 0.8770491803278688, 'F1-score (class 1)': 0.8761514841351075}
