<a href="https://colab.research.google.com/github/goerlitz/nlp-classification/blob/main/notebooks/10kGNAD/colab/21_10kGNAD_simpletransformers_default.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Classifying German News Articles with SimpleTransformers

## Objectives

1. Train a text classifier with transfer learning based on a pretrained German transformer model.
2. Keep the implementation simple (just a few lines of code) by using the SimpleTransformers library, which also provides sensible default model settings.


## Approach

Fine-tune the following pretrained models on the 10k German News Articles dataset (10kGNAD) to classify articles into 9 news topics:

* `distilbert-base-german-cased`
* `deepset/gbert-base`
* `deepset/gelectra-large`

## Learnings

...

## Prerequisites

In [1]:
model_type = "distilbert"
model_name = "distilbert-base-german-cased"

# model_type = "bert"
# model_name = "deepset/gbert-base"

# model_type = "electra"
# model_name = "deepset/gelectra-base"

project_name = "10kgnad_default__" + model_name.replace("/", "_")

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Nov 14 21:11:52 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 495.44       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   51C    P0    30W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [4]:
# install/upgrade tqdm, transformers and simpletransformers quietly
!pip install -q -U tqdm transformers simpletransformers >/dev/null

# check installed versions
!pip freeze | grep transformers
!pip freeze | grep torch
# previously / currently installed versions:
# simpletransformers==0.61.6 / 0.63.3
# transformers==4.6.1 / 4.12.3
# torch==1.8.1+cu101 / 1.10.0

simpletransformers==0.63.3
transformers==4.12.3
torch @ https://download.pytorch.org/whl/cu111/torch-1.10.0%2Bcu111-cp37-cp37m-linux_x86_64.whl
torchsummary==1.5.1
torchtext==0.11.0
torchvision @ https://download.pytorch.org/whl/cu111/torchvision-0.11.1%2Bcu111-cp37-cp37m-linux_x86_64.whl


In [5]:
import numpy as np
import pandas as pd
from pathlib import Path
import os

from simpletransformers.classification import ClassificationModel
from transformers import AutoTokenizer
from transformers import logging
import wandb

# hide progress bar when downloading tokenizers - a workaround!
logging.get_verbosity = lambda : logging.NOTSET

# suppress transformers warnings such as "Some weights of the model checkpoint ..."
logging.set_verbosity_error()

# silence wandb console output (runs are still logged to W&B)
os.environ["WANDB_SILENT"] = "true"

## Download Data

We use the [10k German News Articles Dataset (10kGNAD)](https://tblock.github.io/10kGNAD/), which comes pre-split into `train.csv` and `test.csv`.

In [6]:
%env DIR=data

!mkdir -p $DIR
!wget -nc https://github.com/tblock/10kGNAD/blob/master/train.csv?raw=true -nv -O $DIR/train.csv
!wget -nc https://github.com/tblock/10kGNAD/blob/master/test.csv?raw=true -nv -O $DIR/test.csv
!ls -lAh $DIR | cut -d " " -f 5-

env: DIR=data
2021-11-14 21:15:01 URL:https://raw.githubusercontent.com/tblock/10kGNAD/master/train.csv [24405789/24405789] -> "data/train.csv" [1]
2021-11-14 21:15:03 URL:https://raw.githubusercontent.com/tblock/10kGNAD/master/test.csv [2755020/2755020] -> "data/test.csv" [1]

2.7M Nov 14 21:15 test.csv
 24M Nov 14 21:15 train.csv
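For environments without `wget`, the download step can be sketched in plain Python. This is a minimal sketch, assuming the raw-file URLs that `wget` was redirected to in the log above stay valid; `download_if_missing` and `BASE_URL` are hypothetical names, and skipping files that already exist only mirrors the intent of `-nc`.

```python
import urllib.request
from pathlib import Path

# Assumed base URL, taken from the redirect shown in the wget log above.
BASE_URL = "https://raw.githubusercontent.com/tblock/10kGNAD/master/"


def download_if_missing(url: str, dest: Path) -> bool:
    """Download `url` to `dest` unless `dest` already exists (similar in spirit to `wget -nc`).

    Returns True if a download happened, False if the file was already present.
    """
    dest = Path(dest)
    if dest.exists():
        return False
    dest.parent.mkdir(parents=True, exist_ok=True)
    urllib.request.urlretrieve(url, dest)
    return True
```

For example, `download_if_missing(BASE_URL + "train.csv", Path("data/train.csv"))` fetches the training split once and returns `False` on later calls.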


## Import Data

In [7]:
data_dir = Path("data/")

train_file = data_dir / 'train.csv'
test_file = data_dir / 'test.csv'

def read_csv_10kGNAD(filepath: Path, columns=("labels", "text")) -> pd.DataFrame:
    """Load a 10kGNAD csv file: ';'-separated, text quoted with "'", no header row."""
    return pd.read_csv(filepath, sep=";", quotechar="'", names=list(columns))
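The two `read_csv` options matter because label and article text are separated by `;` and the text is wrapped in single quotes, so a semicolon inside an article must not start a new column. The sketch below parses two invented lines (not taken from the dataset) with the same options.

```python
import io

import pandas as pd

# Two made-up lines in the layout expected by read_csv_10kGNAD: label;'text'
# The first text contains a semicolon inside the quoted field.
sample = (
    "Sport;'Tor in der Nachspielzeit; der Titel ist entschieden.'\n"
    "Web;'Neues Smartphone vorgestellt.'\n"
)

df = pd.read_csv(io.StringIO(sample), sep=";", quotechar="'", names=["labels", "text"])
print(df)
```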

In [37]:
train_df = read_csv_10kGNAD(train_file)
print(train_df.shape[0], 'articles')
display(train_df.head())

9245 articles


|   | labels | text |
|---|--------|------|
| 0 | Sport | 21-Jähriger fällt wohl bis Saisonende aus. Wie... |
| 1 | Kultur | Erfundene Bilder zu Filmen, die als verloren g... |
| 2 | Web | Der frischgekürte CEO Sundar Pichai setzt auf ... |
| 3 | Wirtschaft | Putin: "Einigung, dass wir Menge auf Niveau vo... |
| 4 | Inland | Estland sieht den künftigen österreichischen P... |


In [38]:
test_df = read_csv_10kGNAD(test_file)
print(test_df.shape[0], 'articles')
display(test_df.head())

1028 articles


|   | labels | text |
|---|--------|------|
| 0 | Wirtschaft | Die Gewerkschaft GPA-djp lanciert den "All-in-... |
| 1 | Sport | Franzosen verteidigen 2:1-Führung – Kritische ... |
| 2 | Web | Neues Video von Designern macht im Netz die Ru... |
| 3 | Sport | 23-jähriger Brasilianer muss vier Spiele pausi... |
| 4 | International | Aufständische verwendeten Chemikalie bei Gefec... |


## Prepare Data for Model Training

There are a few requirements for feeding training data into SimpleTransformers:

* the columns should be named `labels` and `text` (already done when reading the data)
* labels must be encoded as integers from `0` to `num_labels - 1`

Additionally, imbalanced datasets can be handled by computing class weights for training.

### Label Encoding

In [39]:
from sklearn.preprocessing import LabelEncoder

def encode_labels(train: pd.DataFrame, test: pd.DataFrame):
    le = LabelEncoder()

    train_labels = le.fit_transform(train.labels)
    test_labels = le.transform(test.labels)

    return train.assign(labels=train_labels), test.assign(labels=test_labels), le

# caution: this overwrites train_df and test_df with their label-encoded versions
train_df, test_df, le = encode_labels(train_df, test_df)
display(train_df.head())

|   | labels | text |
|---|--------|------|
| 0 | 5 | 21-Jähriger fällt wohl bis Saisonende aus. Wie... |
| 1 | 3 | Erfundene Bilder zu Filmen, die als verloren g... |
| 2 | 6 | Der frischgekürte CEO Sundar Pichai setzt auf ... |
| 3 | 7 | Putin: "Einigung, dass wir Menge auf Niveau vo... |
| 4 | 1 | Estland sieht den künftigen österreichischen P... |
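`LabelEncoder` stores the class names sorted in its `classes_` attribute and uses each name's position as its integer code, which explains the codes in the table above (e.g. `Sport` → 5, `Web` → 6). The mapping can be reproduced without scikit-learn as a quick sanity check; the `topics` list is typed in by hand from the dataset's nine class names.

```python
# The nine 10kGNAD topics, in arbitrary order (as they might appear in the data).
topics = ["Sport", "Kultur", "Web", "Wirtschaft", "Inland",
          "International", "Panorama", "Etat", "Wissenschaft"]

# Like LabelEncoder: sort the unique names and number them from 0.
classes = sorted(set(topics))
label_to_id = {label: i for i, label in enumerate(classes)}
id_to_label = dict(enumerate(classes))

print(label_to_id)
```

In the notebook itself, `le.inverse_transform(...)` performs the reverse lookup from integer ids back to topic names.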


### Computing Class Weights (not used yet)

In [11]:
from sklearn.utils.class_weight import compute_class_weight

def class_weights(labels: pd.Series) -> pd.DataFrame:
    """Compute 'balanced' class weights for imbalanced data."""
    classes = np.sort(labels.unique())
    counts = labels.value_counts().reindex(classes)
    # current scikit-learn versions make `classes` and `y` keyword-only arguments
    weights = compute_class_weight("balanced", classes=classes, y=labels)
    return pd.DataFrame({"count": counts, "weight": weights}, index=classes)

weights_df = class_weights(train_df.labels)
display(weights_df)

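| label | count | weight |
|-------|-------|--------|
| 0 | 601 | 1.709188 |
| 1 | 913 | 1.125106 |
| 2 | 1360 | 0.75531 |
| 3 | 485 | 2.117984 |
| 4 | 1510 | 0.68028 |
| 5 | 1081 | 0.950252 |
| 6 | 1509 | 0.68073 |
| 7 | 1270 | 0.808836 |
| 8 | 516 | 1.990741 |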
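The `"balanced"` heuristic weights each class by `n_samples / (n_classes * count)`, so rare topics such as class 3 (485 articles) get weights above 1 and frequent ones below 1. Recomputing the weights by hand from the counts in the table confirms this; the snippet uses only the numbers shown there.

```python
# Per-class training counts copied from the table above (label id -> count).
counts = {0: 601, 1: 913, 2: 1360, 3: 485, 4: 1510,
          5: 1081, 6: 1509, 7: 1270, 8: 516}

n_samples = sum(counts.values())  # total number of training articles
n_classes = len(counts)           # number of topics

# "balanced" heuristic: weight_c = n_samples / (n_classes * count_c)
weights = {c: n_samples / (n_classes * n) for c, n in counts.items()}

for c, w in weights.items():
    print(f"{c}: {w:.6f}")
```

If the weights are used later, my understanding is that SimpleTransformers' `ClassificationModel` accepts them as a list through its `weight` argument (e.g. `weight=weights_df.weight.tolist()`); please check the documentation of the installed version.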


## Model Setup

In [12]:
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

def f1_multiclass(labels, preds):
    return f1_score(labels, preds, average='macro')

def precision_multiclass(labels, preds):
    return precision_score(labels, preds, average='macro')

def recall_multiclass(labels, preds):
    return recall_score(labels, preds, average='macro')
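With `average='macro'`, scikit-learn computes precision, recall and F1 per class and takes their unweighted mean, so small topics count as much as large ones. The plain-Python sketch below reimplements macro F1 for illustration (`macro_f1` is a hypothetical helper, not used elsewhere in the notebook) and shows on a toy example that it can differ from accuracy: 3 of 4 predictions are correct (accuracy 0.75), while the macro F1 is 11/15.

```python
def macro_f1(labels, preds):
    """Unweighted mean of per-class F1 scores over all classes seen in labels or preds."""
    scores = []
    for c in sorted(set(labels) | set(preds)):
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        scores.append(f1)
    return sum(scores) / len(scores)


labels = [0, 0, 0, 1]
preds = [0, 0, 1, 1]
# class 0: P = 1, R = 2/3 -> F1 = 0.8; class 1: P = 0.5, R = 1 -> F1 = 2/3
print(macro_f1(labels, preds))
```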

In [15]:
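train_args = {
    "reprocess_input_data": True,              # re-tokenize inputs instead of reusing cached features
    "overwrite_output_dir": True,              # allow repeated runs to reuse outputs/
    "evaluate_during_training": True,          # evaluate on eval_df while training
    "evaluate_during_training_steps": 200,     # evaluate every 200 steps
    "evaluate_during_training_verbose": False, # don't print intermediate results
    "evaluate_during_training_silent": True,   # no progress bars during evaluation
    "silent": True,                            # no progress bars during training
    "fp16": False,                             # train in full precision
    "wandb_project": project_name,             # log runs to this W&B project
}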
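Because `train_args` does not set the batch size or the number of epochs, SimpleTransformers falls back to its defaults, which, as far as I know, are `train_batch_size=8` and `num_train_epochs=1`. Under that assumption the number of optimization steps and the evaluation points can be worked out in advance; the result agrees with the `checkpoint-1156-epoch-1` directory and the logged `global_step` values.

```python
import math

n_train = 9245          # training articles (see "Import Data")
train_batch_size = 8    # assumed SimpleTransformers default
num_train_epochs = 1    # assumed SimpleTransformers default
eval_every = 200        # "evaluate_during_training_steps" in train_args

steps_per_epoch = math.ceil(n_train / train_batch_size)
total_steps = steps_per_epoch * num_train_epochs

# Evaluation every `eval_every` steps, plus a final one at the end of training.
eval_steps = list(range(eval_every, total_steps + 1, eval_every))
if eval_steps[-1] != total_steps:
    eval_steps.append(total_steps)

print(total_steps, eval_steps)
```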

In [16]:
def train():
    # create the tokenizer first to pass its lower-case setting on to the model args
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    args = {**train_args, "do_lower_case": tokenizer.do_lower_case}

    # create a ClassificationModel with one output per news topic
    model = ClassificationModel(model_type,
                                model_name,
                                num_labels=train_df.labels.nunique(),
                                args=args)

    # fine-tune; test_df serves as the evaluation set during training
    steps, details = model.train_model(train_df,
                                       eval_df=test_df,
                                       verbose=False,
                                       f1=f1_multiclass,
                                       acc=accuracy_score,
                                       precision=precision_multiclass,
                                       recall=recall_multiclass)

    print(details)

    # finish the W&B run so that the next call starts a new one
    wandb.finish()

In [17]:
# continuously run experiments until stopped
while True:
    train()

defaultdict(<class 'list'>, {'global_step': [200, 400, 600, 800, 1000, 1156], 'train_loss': [0.9056264162063599, 0.502555787563324, 1.1973752975463867, 0.19918952882289886, 1.3784610033035278, 0.019147220999002457], 'mcc': [0.7810989186463829, 0.8211842141156012, 0.8550705617234949, 0.8630530762626466, 0.8662400023279396, 0.8728953828389182], 'f1': [0.7937181876945638, 0.831150622432685, 0.8660478008815321, 0.8770207240980252, 0.8781399898967925, 0.8846282749358215], 'acc': [0.806420233463035, 0.8433852140077821, 0.8735408560311284, 0.8803501945525292, 0.8832684824902723, 0.8891050583657587], 'precision': [0.8053801727116544, 0.8536782373418692, 0.8730085811584363, 0.8785741343605326, 0.8858663863182452, 0.890758892696854], 'recall': [0.794875218744787, 0.82494317761626, 0.8621389508525419, 0.8763035541791978, 0.8723403438638082, 0.8794234920774758], 'eval_loss': [0.6587459397639415, 0.5280436280854913, 0.4176403791607

Process ForkPoolWorker-49:
Traceback (most recent call last):
  File "/usr/lib/python3.7/multiprocessing/process.py", line 297, in _bootstrap
    self.run()
  File "/usr/lib/python3.7/multiprocessing/process.py", line 99, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python3.7/multiprocessing/pool.py", line 110, in worker
    task = get()
  File "/usr/lib/python3.7/multiprocessing/queues.py", line 354, in get
    return _ForkingPickler.loads(res)
KeyboardInterrupt


AssertionError: ignored
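The `details` dict printed by `train()` holds one value per evaluation for each metric, so `pd.DataFrame(details)` turns it into a table. Since the printed output above is truncated, the sketch below copies only the `global_step`, `f1` and `acc` values of the first run and picks the evaluation step with the highest macro F1.

```python
import pandas as pd

# Subset of the training-progress scores printed above (first run).
progress = pd.DataFrame({
    "global_step": [200, 400, 600, 800, 1000, 1156],
    "f1": [0.7937181876945638, 0.831150622432685, 0.8660478008815321,
           0.8770207240980252, 0.8781399898967925, 0.8846282749358215],
    "acc": [0.806420233463035, 0.8433852140077821, 0.8735408560311284,
            0.8803501945525292, 0.8832684824902723, 0.8891050583657587],
})

# Row of the evaluation with the highest macro F1.
best = progress.loc[progress["f1"].idxmax()]
print(int(best["global_step"]), round(best["f1"], 4))
```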

---
## Evaluate Best Model

In [18]:
!ls -la outputs/

total 264112
drwxr-xr-x 9 root root      4096 Nov 14 21:22 .
drwxr-xr-x 1 root root      4096 Nov 14 21:19 ..
drwxr-xr-x 2 root root      4096 Nov 14 21:20 best_model
drwxr-xr-x 2 root root      4096 Nov 14 21:21 checkpoint-1000
drwxr-xr-x 2 root root      4096 Nov 14 21:22 checkpoint-1156-epoch-1
drwxr-xr-x 2 root root      4096 Nov 14 21:20 checkpoint-200
drwxr-xr-x 2 root root      4096 Nov 14 21:20 checkpoint-400
drwxr-xr-x 2 root root      4096 Nov 14 21:21 checkpoint-600
drwxr-xr-x 2 root root      4096 Nov 14 21:21 checkpoint-800
-rw-r--r-- 1 root root      1024 Nov 14 21:37 config.json
-rw-r--r-- 1 root root       165 Nov 14 21:37 eval_results.txt
-rw-r--r-- 1 root root      2685 Nov 14 21:37 model_args.json
-rw-r--r-- 1 root root 269663345 Nov 14 21:37 pytorch_model.bin
-rw-r--r-- 1 root root       112 Nov 14 21:37 special_tokens_map.json
-rw-r--r-- 1 root root       339 Nov 14 21:37 tokenizer_config.json
-rw-r--r-- 1 root root    479105 Nov 14 21:37 tokenizer.json
-rw-r--r-- 

In [19]:
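# load the best model (as stored by SimpleTransformers)
# CAUTION: for some reason this seems to be the last model, not the best one;
# note that SimpleTransformers selects best_model by `early_stopping_metric`
# (default: `eval_loss`), not by accuracy or F1
model = ClassificationModel(model_type, "outputs/best_model")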
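If `best_model` should reflect macro F1 rather than the evaluation loss, the selection metric can presumably be changed through the training arguments. This is a sketch, assuming that `early_stopping_metric` and `early_stopping_metric_minimize` also drive best-model selection when early stopping itself is off, and that a custom metric name such as `f1` (passed to `train_model`) is a valid value; please verify against the SimpleTransformers documentation. `best_model_by_f1` and `base_args` are hypothetical names.

```python
# Hypothetical extra arguments: pick best_model by the "f1" metric passed to train_model.
best_model_by_f1 = {
    "early_stopping_metric": "f1",
    "early_stopping_metric_minimize": False,  # higher F1 is better
}

# Stand-in for train_args from "Model Setup"; merge the selection settings on top.
base_args = {"evaluate_during_training": True, "evaluate_during_training_steps": 200}
args = {**base_args, **best_model_by_f1}
print(args)
```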

In [21]:
result, model_outputs, wrong_predictions = model.eval_model(
    test_df,
    f1=f1_multiclass,
    acc=accuracy_score,
    precision=precision_multiclass,
    recall=recall_multiclass,
    wandb_log=False,  # don't log this evaluation to W&B
)
pd.Series(result)

mcc          0.864052
f1           0.876920
acc          0.881323
precision    0.878769
recall       0.876168
eval_loss    0.395464
dtype: float64

In [43]:
preds = pd.DataFrame(model_outputs, columns=le.classes_)
preds

|   | Etat | Inland | International | Kultur | Panorama | Sport | Web | Wirtschaft | Wissenschaft |
|---|------|--------|---------------|--------|----------|-------|-----|------------|--------------|
| 0 | -1.246303 | 2.165673 | -1.116495 | -1.765686 | -0.496776 | -1.936629 | -0.564982 | 4.745736 | -1.065999 |
| 1 | -1.048553 | -1.501588 | -0.883474 | -1.832461 | -0.932395 | 6.013526 | -1.149155 | -1.175344 | -1.375613 |
| 2 | -0.315572 | -0.960655 | -0.957032 | -1.571495 | -1.350365 | -2.017533 | 5.953548 | -0.292041 | -0.866429 |
| 3 | -1.085535 | -1.567847 | -0.826758 | -1.800891 | -0.636021 | 5.930137 | -1.195616 | -1.261313 | -1.393840 |
| 4 | -1.028808 | -1.209813 | 6.033928 | -1.352837 | -0.140306 | -1.330870 | -0.569556 | -0.307684 | -1.286445 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 1023 | 0.221346 | -1.229265 | -1.224683 | -1.400338 | -1.309340 | -2.240077 | 5.789686 | 0.240703 | -1.236409 |
| 1024 | -0.861510 | 4.993564 | 0.650322 | -1.765656 | -0.244199 | -1.658149 | -0.877266 | 0.471486 | -1.234420 |
| 1025 | -0.954487 | -1.386584 | -0.909206 | -1.684109 | -0.585073 | 5.869841 | -1.212807 | -1.411802 | -1.318690 |
| 1026 | -1.029463 | -1.563712 | -0.991201 | -1.817411 | -0.898136 | 6.012200 | -1.102147 | -1.260115 | -1.293819 |
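The columns of `preds` are raw logits, not probabilities. A softmax turns each row into a probability distribution over the nine topics; since it is monotonic, the arg max (and hence the predicted topic) stays the same. The sketch applies it to the first test article, whose logits are copied from the table above and whose true label is `Wirtschaft`.

```python
import numpy as np

classes = ["Etat", "Inland", "International", "Kultur", "Panorama",
           "Sport", "Web", "Wirtschaft", "Wissenschaft"]

# Raw model outputs (logits) for test article 0, copied from the table above.
logits = np.array([-1.246303, 2.165673, -1.116495, -1.765686, -0.496776,
                   -1.936629, -0.564982, 4.745736, -1.065999])


def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax along the last axis."""
    z = x - x.max(axis=-1, keepdims=True)  # subtracting the max does not change the result
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


probs = softmax(logits)
print(classes[int(probs.argmax())], round(float(probs.max()), 3))
```

Applied to the whole output, `softmax(preds.to_numpy())` yields one probability row per test article.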


In [None]:
# preds.to_csv("data/predictions.csv", index=False)

In [44]:
# predicted label id = index of the highest output score per article
pred_s = pd.DataFrame(model_outputs).idxmax(axis=1)

In [45]:
import sklearn.metrics as skm
skm.confusion_matrix(test_df.labels, pred_s)

array([[ 54,   3,   1,   1,   2,   1,   3,   2,   0],
       [  1,  85,   2,   2,   6,   0,   0,   4,   2],
       [  2,   1, 130,   0,  10,   2,   1,   4,   1],
       [  1,   2,   1,  47,   1,   0,   0,   0,   2],
       [  1,   6,  10,   2, 138,   0,   1,   6,   4],
       [  0,   0,   0,   0,   2, 118,   0,   0,   0],
       [  1,   1,   1,   0,   0,   0, 159,   5,   1],
       [  0,   4,   2,   0,   6,   0,   2, 125,   2],
       [  0,   1,   3,   0,   1,   1,   0,   1,  50]])
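In scikit-learn's confusion matrix, rows are true labels and columns are predictions, both in label-id order (`Etat` first, `Wissenschaft` last). The per-class scores of the classification report follow directly from it: recall divides the diagonal by the row sums, precision by the column sums. The sketch recomputes them from the printed matrix; the accuracy 906/1028 and the mean F1 match the `acc` (0.881323) and `f1` (0.876920) values returned by `eval_model` above.

```python
import numpy as np

# Confusion matrix from above: rows = true labels, columns = predicted labels.
cm = np.array([
    [ 54,   3,   1,   1,   2,   1,   3,   2,   0],
    [  1,  85,   2,   2,   6,   0,   0,   4,   2],
    [  2,   1, 130,   0,  10,   2,   1,   4,   1],
    [  1,   2,   1,  47,   1,   0,   0,   0,   2],
    [  1,   6,  10,   2, 138,   0,   1,   6,   4],
    [  0,   0,   0,   0,   2, 118,   0,   0,   0],
    [  1,   1,   1,   0,   0,   0, 159,   5,   1],
    [  0,   4,   2,   0,   6,   0,   2, 125,   2],
    [  0,   1,   3,   0,   1,   1,   0,   1,  50],
])

tp = np.diag(cm)
recall = tp / cm.sum(axis=1)     # correct / all articles of that true class
precision = tp / cm.sum(axis=0)  # correct / all articles predicted as that class
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm.sum()

print(np.round(precision, 2))
print(np.round(recall, 2))
print(round(accuracy, 6), round(f1.mean(), 6))
```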

In [47]:
print(skm.classification_report(test_df.labels, pred_s, target_names=le.classes_))

               precision    recall  f1-score   support

         Etat       0.90      0.81      0.85        67
       Inland       0.83      0.83      0.83       102
International       0.87      0.86      0.86       151
       Kultur       0.90      0.87      0.89        54
     Panorama       0.83      0.82      0.83       168
        Sport       0.97      0.98      0.98       120
          Web       0.96      0.95      0.95       168
   Wirtschaft       0.85      0.89      0.87       141
 Wissenschaft       0.81      0.88      0.84        57

     accuracy                           0.88      1028
    macro avg       0.88      0.88      0.88      1028
 weighted avg       0.88      0.88      0.88      1028

