In [None]:
!pip install datasets



In [None]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch

In [None]:
# Load every post (train + test subsets) and make our own 80/20 split.
# NOTE(review): by default each post keeps its headers, signature block and
# quoted replies. The scikit-learn docs warn these make classifiers overfit and
# inflate scores. Set NEWSGROUPS_REMOVE = ("headers", "footers", "quotes") for
# a more realistic benchmark. The default () keeps the full text, as before.
NEWSGROUPS_REMOVE = ()
SPLIT_SEED = 42

newsgroups = fetch_20newsgroups(subset="all", remove=NEWSGROUPS_REMOVE)
X, y = newsgroups.data, newsgroups.target

# stratify=y keeps each of the 20 class proportions the same in train and test,
# so the per-class support in the final report is not skewed by chance.
train_texts, test_texts, train_labels, test_labels = train_test_split(
    X, y, test_size=0.2, random_state=SPLIT_SEED, stratify=y
)

In [None]:
from datasets import Dataset, DatasetDict

# One Hugging Face Dataset per split. "text" is the column the tokenization cell
# reads; "label" holds the integer newsgroup ids from the split above.
train_dataset, test_dataset = (
    Dataset.from_dict({"text": texts, "label": labels})
    for texts, labels in ((train_texts, train_labels), (test_texts, test_labels))
)

# Group the splits so later cells can process both with a single call.
# NOTE: this variable reuses the name of the `datasets` package. That is harmless
# here only because the package itself is never bound (just Dataset/DatasetDict).
datasets = DatasetDict({"train": train_dataset, "test": test_dataset})

In [None]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Derive the label space from the data instead of hard-coding 20. Storing the
# readable newsgroup names in the model config lets saved checkpoints and
# pipelines report names rather than bare ids.
id2label = dict(enumerate(newsgroups.target_names))
label2id = {name: idx for idx, name in id2label.items()}

# The classifier / pre_classifier layers are newly (randomly) initialised by
# this call. Seed torch first so reruns start from the same head weights.
# 42 matches the split's random_state.
torch.manual_seed(42)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def preprocess_function(examples, padding="max_length"):
    """Tokenize a batch of raw newsgroup posts for the classifier.

    Args:
        examples: a batch from ``Dataset.map(..., batched=True)``. Only its
            "text" column (a list of strings) is read.
        padding: forwarded to the tokenizer. The default ``"max_length"`` keeps
            the original behaviour (every sequence padded to the tokenizer's
            maximum length). ``False`` leaves sequences unpadded so they can be
            padded per batch at training time.

    Returns:
        The tokenizer's output for the batch (e.g. input_ids, attention_mask),
        with every sequence truncated to the model's maximum length.
    """
    return tokenizer(examples["text"], truncation=True, padding=padding)


# Store sequences unpadded. Padding every post to the model maximum (presumably
# 512 tokens for distilbert-base-uncased) makes short posts mostly padding and
# slows every training step. The Trainer below is given the tokenizer and pads
# each batch only to its longest sequence instead.
tokenized_datasets = datasets.map(
    preprocess_function, batched=True, fn_kwargs={"padding": False}
)

Map:   0%|          | 0/15076 [00:00<?, ? examples/s]

Map:   0%|          | 0/3770 [00:00<?, ? examples/s]

In [None]:
# NOTE(review): this cell is disabled; the training run recorded below ran
# without it. As written, it would restrict the formatted output to
# input_ids/attention_mask, presumably dropping the "label" column the model
# needs to compute its loss. Confirm before re-enabling, or delete the cell.
# tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask'])

In [None]:
def compute_metrics(eval_pred):
    """Compute accuracy for Trainer's per-epoch evaluation.

    Args:
        eval_pred: the (predictions, label_ids) pair Trainer passes in. The
            predictions are the classifier's per-class scores.

    Returns:
        dict with key "accuracy", which Trainer logs as "eval_accuracy".
    """
    logits, labels = eval_pred
    return {"accuracy": accuracy_score(labels, logits.argmax(-1))}


training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    # Checkpoint once per epoch, in step with evaluation, and keep only the two
    # most recent. The default checkpoints every 500 steps with no limit, which
    # for 4715 steps leaves ~9 full checkpoints (model + optimizer state) on disk.
    save_strategy="epoch",
    save_total_limit=2,
    learning_rate=1e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
    seed=42,  # Trainer's default, made explicit for reproducibility
)

# NOTE: eval_dataset is the same test split scored in the final evaluation
# cell, so here it only monitors progress. Do not enable load_best_model_at_end
# (model selection) on it without first carving out a separate validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    # With a tokenizer supplied, Trainer's default collator pads each batch
    # dynamically. NOTE(review): newer transformers releases rename this
    # argument to `processing_class`.
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

Epoch,Training Loss,Validation Loss
1,0.9399,0.509467
2,0.3255,0.409498
3,0.1603,0.400731


Epoch,Training Loss,Validation Loss
1,0.9399,0.509467
2,0.3255,0.409498
3,0.1603,0.400731
4,0.0779,0.396493
5,0.0362,0.409647


TrainOutput(global_step=4715, training_loss=0.2658366144651573, metrics={'train_runtime': 4207.9434, 'train_samples_per_second': 17.914, 'train_steps_per_second': 1.12, 'total_flos': 9988597866086400.0, 'train_loss': 0.2658366144651573, 'epoch': 5.0})

In [None]:
# Score the model as it stands after the last training epoch on the held-out split.
predictions = trainer.predict(tokenized_datasets['test'])
# predictions.predictions holds one score per class; the predicted newsgroup is
# the highest-scoring column.
pred_labels = predictions.predictions.argmax(axis=-1)

accuracy = accuracy_score(y_true=test_labels, y_pred=pred_labels)
print(f"Accuracy: {accuracy}")

# Per-newsgroup precision / recall / F1, with rows labelled by group name.
print(
    classification_report(
        y_true=test_labels,
        y_pred=pred_labels,
        target_names=newsgroups.target_names,
    )
)

Accuracy: 0.9249336870026525
                          precision    recall  f1-score   support

             alt.atheism       0.90      0.94      0.92       151
           comp.graphics       0.88      0.87      0.88       202
 comp.os.ms-windows.misc       0.85      0.87      0.86       195
comp.sys.ibm.pc.hardware       0.72      0.80      0.76       183
   comp.sys.mac.hardware       0.88      0.90      0.89       205
          comp.windows.x       0.96      0.92      0.94       215
            misc.forsale       0.90      0.83      0.86       193
               rec.autos       0.96      0.96      0.96       196
         rec.motorcycles       0.95      0.97      0.96       168
      rec.sport.baseball       0.99      0.99      0.99       211
        rec.sport.hockey       0.98      0.98      0.98       198
               sci.crypt       0.97      0.96      0.96       201
         sci.electronics       0.90      0.89      0.89       202
                 sci.med       0.99      0.97 