In [1]:
from datasets import load_dataset
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
import numpy as np
from sklearn.metrics import accuracy_score

In [15]:
# Load the train and test splits of the AG News dataset ("fancyzhx/ag_news")
agnews_train, agnews_test = (
    load_dataset("fancyzhx/ag_news", split=split_name)
    for split_name in ("train", "test")
)

In [16]:
# Build the fast DistilBERT tokenizer from the 'distilbert-base-uncased' checkpoint
tokenizer = DistilBertTokenizerFast.from_pretrained(
    'distilbert-base-uncased',
)

In [17]:
# Batch tokenization helper (applied to the dataset splits via `.map`)
def tokenize_function(examples):
    """Tokenize the 'text' field of a batch with the module-level `tokenizer`.

    Args:
        examples: mapping that provides the raw inputs under the 'text' key.

    Returns:
        Whatever `tokenizer` returns for those texts when called with
        padding="max_length" and truncation=True.
    """
    # NOTE(review): no explicit max_length is passed, so every example is
    # presumably padded to the tokenizer's model maximum - confirm; dynamic
    # per-batch padding may be cheaper for short texts.
    batch_texts = examples['text']
    return tokenizer(batch_texts, truncation=True, padding="max_length")

In [18]:
# Run tokenize_function over both splits in batched mode
tokenized_train, tokenized_test = (
    split.map(tokenize_function, batched=True)
    for split in (agnews_train, agnews_test)
)

In [20]:
# Load DistilBERT for sequence classification with 4 output labels.
# The classifier / pre_classifier weights are newly initialized (see the
# message printed below), so the model must be trained before use.
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased',
    num_labels=4,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [19]:
# Metric callback handed to the Trainer: reports classification accuracy
def compute_metrics(p):
    """Score arg-max predictions against the reference labels.

    Args:
        p: object exposing `predictions` (scores reduced with argmax along
           axis 1 to obtain one predicted class per row) and `label_ids`
           (the reference labels).

    Returns:
        dict with a single "accuracy" entry, as computed by `accuracy_score`.
    """
    predicted_classes = np.argmax(p.predictions, axis=1)
    return {"accuracy": accuracy_score(p.label_ids, predicted_classes)}

In [21]:
# Training configuration for the run
training_args = TrainingArguments(
    # where the run writes its results
    output_dir='./results',
    # schedule / optimisation settings
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    # batch sizes per device
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # evaluation cadence
    eval_strategy="epoch",
)

In [22]:
# Assemble the Trainer from the model, its arguments, the tokenized splits
# and the accuracy callback
trainer = Trainer(
    args=training_args,
    model=model,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
)

In [23]:
# Train the model.
# The call is the cell's last expression, so its return value (the
# TrainOutput shown below) is displayed; in the recorded run this took
# 7500 steps and about 1600 s of train_runtime for the single epoch.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy
1,0.1971,0.17594,0.943158


TrainOutput(global_step=7500, training_loss=0.22699112752278647, metrics={'train_runtime': 1599.7664, 'train_samples_per_second': 75.011, 'train_steps_per_second': 4.688, 'total_flos': 1.589665480704e+16, 'train_loss': 0.22699112752278647, 'epoch': 1.0})

In [24]:
# Evaluate the model and keep the returned metrics.
# No dataset is passed here; presumably the eval_dataset given to the
# Trainer is used - TODO confirm. In the recorded run the loss/accuracy match
# the validation row reported at the end of the training epoch.
eval_results = trainer.evaluate()

In [25]:
# Print the metrics returned by trainer.evaluate()
print(f"Evaluation results: {eval_results}")

Evaluation results: {'eval_loss': 0.17593984305858612, 'eval_accuracy': 0.9431578947368421, 'eval_runtime': 34.8468, 'eval_samples_per_second': 218.098, 'eval_steps_per_second': 13.631, 'epoch': 1.0}
