In [1]:
from datetime import date
from pathlib import Path

import numpy as np
import pandas as pd
from simpletransformers.classification import ClassificationModel
from sklearn.metrics import classification_report, confusion_matrix, f1_score, matthews_corrcoef
import spacy
from spacy.lang.nb.stop_words import STOP_WORDS



In [2]:
today = date.today().isoformat()
SAVE_PATH = Path('model')
SAVE_PATH.mkdir(exist_ok=True)
DATA_PATH = Path('data/norec')
DATA_PATH = Path('.')

In [3]:
subset_names = ['train', 'test', 'dev']
subsets = {name: pd.read_pickle(DATA_PATH / f'norsk_kategori_{name}.pkl') for name in subset_names}

In [4]:
subsets['train'].groupby(['rating']).count()

Unnamed: 0_level_0,text
rating,Unnamed: 1_level_1
0,2681
1,14821


In [5]:
subsets['train'].head(1)

Unnamed: 0,text,rating
4848,Franz Ferdinand :\n« You Could Have It So Much...,1


In [6]:
subsets = {name: subsets[name].rename(columns={'rating': 'label'}) for name in subset_names}
subsets['train'].head(1)

Unnamed: 0,text,label
4848,Franz Ferdinand :\n« You Could Have It So Much...,1


In [31]:
model = ClassificationModel(
    'distilbert', 
    'distilbert-base-multilingual-cased', 
    num_labels=2, 
    use_cuda=False,
    weight=[7, 1]
)

Some weights of the model checkpoint at distilbert-base-multilingual-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-multilingual-cased and are newly initialized: ['pre_classifier.weight', 'pre_cla

In [8]:
training_args = {
    'num_train_epochs': 1,
    'learning_rate': 0.005,
    'overwrite_output_dir': True,
    'evaluate_during_training': True,
    'save_model_every_epoch': False
}
model.train_model(
    subsets['train'],
    eval_df=subsets['dev'],
    output_dir=SAVE_PATH,
    args=training_args
)

9s/it][A
Epochs 0/1. Running Loss:    0.3870:  95%|█████████▌| 2082/2188 [4:08:41<10:13,  5.79s/it][A
Epochs 0/1. Running Loss:    0.3870:  95%|█████████▌| 2083/2188 [4:08:46<10:03,  5.75s/it][A
Epochs 0/1. Running Loss:    0.3445:  95%|█████████▌| 2083/2188 [4:08:47<10:03,  5.75s/it][A
Epochs 0/1. Running Loss:    0.3445:  95%|█████████▌| 2084/2188 [4:08:51<09:51,  5.68s/it][A
Epochs 0/1. Running Loss:    0.4096:  95%|█████████▌| 2084/2188 [4:08:53<09:51,  5.68s/it][A
Epochs 0/1. Running Loss:    0.4096:  95%|█████████▌| 2085/2188 [4:08:57<09:49,  5.72s/it][A
Epochs 0/1. Running Loss:    0.1503:  95%|█████████▌| 2085/2188 [4:08:58<09:49,  5.72s/it][A
Epochs 0/1. Running Loss:    0.1503:  95%|█████████▌| 2086/2188 [4:09:03<09:46,  5.75s/it][A
Epochs 0/1. Running Loss:    0.3600:  95%|█████████▌| 2086/2188 [4:09:04<09:46,  5.75s/it][A
Epochs 0/1. Running Loss:    0.3600:  95%|█████████▌| 2087/2188 [4:09:09<09:40,  5.75s/it][A
Epochs 0/1. Running Loss:    0.3722:  95%|████████

(2188, 0.4587006285700754)

In [22]:
print(f'Training set consists of {len(subsets["train"])} samples')
print(f'Dev set consists of {len(subsets["dev"])} samples')

Training set consists of 17502 samples
Dev set consists of 2239 samples


In [23]:
result, model_outputs, wrong_predictions = model.eval_model(subsets['dev'])

100%|██████████| 2239/2239 [00:39<00:00, 57.32it/s]
Running Evaluation: 100%|██████████| 280/280 [11:02<00:00,  2.37s/it]


In [39]:
predictions = [np.argmax(output) for output in model_outputs]

In [40]:
print(classification_report(subsets['dev']['label'], predictions))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00       276
           1       0.88      1.00      0.93      1963

    accuracy                           0.88      2239
   macro avg       0.44      0.50      0.47      2239
weighted avg       0.77      0.88      0.82      2239

