In [1]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
import torch

In [2]:
newsgroups = fetch_20newsgroups(data_home="../data", subset="all")
X, y = newsgroups.data, newsgroups.target
train_texts, test_texts, train_labels, test_labels = train_test_split(X, y, test_size=0.2, random_state=42)

In [3]:
from datasets import Dataset, DatasetDict

train_dataset = Dataset.from_dict({"text": train_texts, "label": train_labels})
test_dataset = Dataset.from_dict({"text": test_texts, "label": test_labels})

datasets = DatasetDict({"train": train_dataset, "test": test_dataset})

In [4]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=20)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length")


tokenized_datasets = datasets.map(preprocess_function, batched=True)

Map:   0%|          | 0/15076 [00:00<?, ? examples/s]

Map:   0%|          | 0/3770 [00:00<?, ? examples/s]

In [6]:
# tokenized_datasets.set_format('torch', columns=['input_ids', 'attention_mask'])

In [7]:
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    learning_rate=2e-4,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    tokenizer=tokenizer,
)

trainer.train()

  0%|          | 0/4715 [00:00<?, ?it/s]

{'loss': 1.1437, 'grad_norm': 5.092325210571289, 'learning_rate': 0.00017879109225874869, 'epoch': 0.53}


  0%|          | 0/236 [00:00<?, ?it/s]

{'eval_loss': 0.6672465801239014, 'eval_runtime': 43.1297, 'eval_samples_per_second': 87.411, 'eval_steps_per_second': 5.472, 'epoch': 1.0}
{'loss': 0.7163, 'grad_norm': 7.7566914558410645, 'learning_rate': 0.00015758218451749736, 'epoch': 1.06}
{'loss': 0.4633, 'grad_norm': 10.152420997619629, 'learning_rate': 0.00013637327677624604, 'epoch': 1.59}


  0%|          | 0/236 [00:00<?, ?it/s]

{'eval_loss': 0.44774916768074036, 'eval_runtime': 42.9614, 'eval_samples_per_second': 87.753, 'eval_steps_per_second': 5.493, 'epoch': 2.0}
{'loss': 0.3783, 'grad_norm': 11.205493927001953, 'learning_rate': 0.0001151643690349947, 'epoch': 2.12}
{'loss': 0.2507, 'grad_norm': 0.25099655985832214, 'learning_rate': 9.395546129374337e-05, 'epoch': 2.65}


  0%|          | 0/236 [00:00<?, ?it/s]

{'eval_loss': 0.45024561882019043, 'eval_runtime': 43.0721, 'eval_samples_per_second': 87.528, 'eval_steps_per_second': 5.479, 'epoch': 3.0}
{'loss': 0.186, 'grad_norm': 0.14138658344745636, 'learning_rate': 7.274655355249205e-05, 'epoch': 3.18}
{'loss': 0.1366, 'grad_norm': 0.14014065265655518, 'learning_rate': 5.153764581124072e-05, 'epoch': 3.71}


  0%|          | 0/236 [00:00<?, ?it/s]

{'eval_loss': 0.46228086948394775, 'eval_runtime': 43.6422, 'eval_samples_per_second': 86.384, 'eval_steps_per_second': 5.408, 'epoch': 4.0}
{'loss': 0.0709, 'grad_norm': 17.50982666015625, 'learning_rate': 3.0328738069989398e-05, 'epoch': 4.24}
{'loss': 0.0554, 'grad_norm': 1.0314443111419678, 'learning_rate': 9.11983032873807e-06, 'epoch': 4.77}


  0%|          | 0/236 [00:00<?, ?it/s]

{'eval_loss': 0.5045596957206726, 'eval_runtime': 43.897, 'eval_samples_per_second': 85.883, 'eval_steps_per_second': 5.376, 'epoch': 5.0}
{'train_runtime': 2775.9224, 'train_samples_per_second': 27.155, 'train_steps_per_second': 1.699, 'train_loss': 0.3633884314641214, 'epoch': 5.0}


TrainOutput(global_step=4715, training_loss=0.3633884314641214, metrics={'train_runtime': 2775.9224, 'train_samples_per_second': 27.155, 'train_steps_per_second': 1.699, 'total_flos': 9988597866086400.0, 'train_loss': 0.3633884314641214, 'epoch': 5.0})

In [8]:
predictions = trainer.predict(tokenized_datasets['test'])
pred_labels = predictions.predictions.argmax(-1)
accuracy = accuracy_score(test_labels, pred_labels)
print(f"Accuracy: {accuracy}")

# Classification report
print(classification_report(test_labels, pred_labels, target_names=newsgroups.target_names))

  0%|          | 0/236 [00:00<?, ?it/s]

Accuracy: 0.9122015915119364
                          precision    recall  f1-score   support

             alt.atheism       0.89      0.94      0.92       151
           comp.graphics       0.87      0.83      0.85       202
 comp.os.ms-windows.misc       0.82      0.88      0.85       195
comp.sys.ibm.pc.hardware       0.74      0.75      0.75       183
   comp.sys.mac.hardware       0.87      0.91      0.89       205
          comp.windows.x       0.97      0.89      0.93       215
            misc.forsale       0.87      0.86      0.87       193
               rec.autos       0.90      0.90      0.90       196
         rec.motorcycles       0.95      0.95      0.95       168
      rec.sport.baseball       0.99      0.98      0.98       211
        rec.sport.hockey       0.97      0.98      0.97       198
               sci.crypt       0.96      0.95      0.95       201
         sci.electronics       0.93      0.85      0.89       202
                 sci.med       0.96      0.96 

In [3]:
a = [(1,2), (3,4)]

list(zip(*a))

[(1, 3), (2, 4)]