In [1]:
from transformers import pipeline, DistilBertTokenizer, DistilBertForSequenceClassification, TrainingArguments, Trainer
from datasets import load_dataset
import os

In [2]:
# Flag intended to switch off the Hugging Face Hub symlinks warning.
# NOTE(review): huggingface_hub appears to read this flag when it is first imported — confirm;
# if so, this must run before `transformers` is imported to have any effect.
os.environ.update(HF_HUB_DISABLE_SYMLINKS_WARNING='1')

In [3]:
# Build a text-classification pipeline from an already fine-tuned checkpoint,
# so it can be used for predictions without any training in this notebook
classifier = pipeline(
    task='text-classification',
    model='distilbert-base-uncased-finetuned-sst-2-english',
)

In [4]:
# Run a single headline through the pipeline and show the raw prediction
text = (
    "Apple unveils new products at its annual technology conference."
)
result = classifier(text)
print(result)

[{'label': 'POSITIVE', 'score': 0.9976915121078491}]


In [5]:
# Load the dataset registered under the "ag_news" id; later cells index it by split name
dataset = load_dataset(path="ag_news")

In [6]:
# Peek at the first five training records (each one prints as a dict)
for sample_idx in range(5):
    sample = dataset["train"][sample_idx]
    print(sample)

{'text': "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.", 'label': 2}
{'text': 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.', 'label': 2}
{'text': "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\\about the economy and the outlook for earnings are expected to\\hang over the stock market next week during the depth of the\\summer doldrums.", 'label': 2}
{'text': 'Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\\flows from the main pipeline in southern Iraq after\\intelligence showed a rebel militia could strike\\infrastructure, an oil official said on Saturday.', 'lab

In [7]:
# Show the definition of the "label" feature, which lists the class names
label_feature = dataset["train"].features["label"]
print(label_feature)

ClassLabel(names=['World', 'Sports', 'Business', 'Sci/Tech'], id=None)


In [8]:
# Report the number of examples in the training split and in the test split
print(*(len(dataset[split_name]) for split_name in ("train", "test")))

120000 7600


In [9]:
# Assign label names to the labels of the first 5 samples in the training set
label_feature = dataset["train"].features["label"]
for sample_idx in range(5):
    label_id = dataset["train"][sample_idx]["label"]
    # int2str maps the integer class id to its human-readable name
    print(label_id, label_feature.int2str(label_id))

2 Business
2 Business
2 Business
2 Business
2 Business


In [10]:
# Bring the loaded data into the input format expected by the DistilBERT model

# Load the tokenizer that ships with the base checkpoint
tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased",
)

# Tokenization function applied to each batch of examples
def tokenize_function(examples, max_length=128):
    """Tokenize the "text" entries of a batch of examples.

    Every sequence is truncated and padded to exactly ``max_length`` tokens.
    Without an explicit ``max_length``, ``padding="max_length"`` pads each
    example up to the tokenizer's model maximum, which makes every batch far
    larger than these short news snippets need and slows training down.

    Args:
        examples: Mapping whose "text" entry holds the raw strings of the batch.
        max_length: Fixed token length of every encoded sequence. The default
            of 128 is a pragmatic cap for short news items — raise it if
            truncating longer articles turns out to hurt accuracy.

    Returns:
        The tokenizer's output for the batch.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Run the tokenization function over the whole dataset, one batch of examples at a time
tokenized_datasets = dataset.map(
    function=tokenize_function,
    batched=True,
)

In [11]:
# Derive the categories from the dataset instead of hard-coding their count,
# so the classification head always matches the data that was loaded
label_names = dataset["train"].features["label"].names
num_labels = len(label_names)

# Load the base checkpoint with a newly initialised classification head.
# The id <-> name mappings are stored in the model config so that predictions
# report class names rather than generic LABEL_n placeholders.
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_labels,
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [13]:
# Hyperparameters and bookkeeping options for the fine-tuning run
training_args = TrainingArguments(
    output_dir="./results",            # where results are written
    eval_strategy="epoch",             # run evaluation once per epoch
    num_train_epochs=3,                # total passes over the training set
    per_device_train_batch_size=16,    # training batch size per device
    per_device_eval_batch_size=16,     # evaluation batch size per device
    learning_rate=2e-5,                # optimizer learning rate
    weight_decay=0.01,                 # regularization strength
)

In [14]:
def compute_metrics(eval_pred):
    """Compute classification accuracy for one evaluation pass.

    Without a metrics callback the Trainer only reports the evaluation loss,
    which says little about how often the predicted category is correct.

    Args:
        eval_pred: Pair of (logits, labels) handed over by the Trainer.
            Assumes the logits arrive as a single array with one row per
            example and one column per class — TODO confirm if the model
            is changed to one that returns extra outputs.

    Returns:
        Dict with a single "accuracy" entry (fraction of correct predictions).
    """
    import numpy as np  # local import: numpy is not imported at the top of this notebook

    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


trainer = Trainer(
    model=model,                                # DistilBert model
    args=training_args,                         # Training arguments
    train_dataset=tokenized_datasets["train"],  # Training dataset
    eval_dataset=tokenized_datasets["test"],    # Evaluation dataset
    compute_metrics=compute_metrics,            # Report accuracy next to the loss
)

In [15]:
# Start fine-tuning with the trainer configured above.
# NOTE(review): the recorded output of this cell ends in a KeyboardInterrupt with an
# empty metrics table, i.e. this run was stopped before any evaluation row was logged.
trainer.train()

Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 