In [1]:
# Noah-Manuel Michael
# Created: 11.05.2023
# Last updated: 13.06.2023
# Reference: https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/text_classification.ipynb
# Fine-tune BERTje (GroNLP/bert-base-dutch-cased) for word order error detection on unpunctuated sentences

import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
from transformers import Trainer, TrainingArguments
from utils_detection import SequenceClassificationDataset


def _texts_and_labels(df, correct_col='no_punc', scrambled_col='scrambled_no_punc'):
    """
    Build a binary word-order dataset from paired sentence columns.

    Every row contributes two examples: the sentence in `correct_col` (label 1) and its
    counterpart in `scrambled_col` (label 0). All correct sentences come first, followed by
    all scrambled ones, so texts[i] and texts[i + len(df)] stem from the same row.

    :param df: DataFrame containing both sentence columns
    :param correct_col: column holding the sentences labelled 1
    :param scrambled_col: column holding the sentences labelled 0
    :return: tuple (texts, labels) of two equally long lists
    """
    texts = df[correct_col].tolist() + df[scrambled_col].tolist()
    # Derive label counts from the row count rather than halving len(texts).
    labels = [1] * len(df) + [0] * len(df)
    return texts, labels


def fine_tune_bertje_no_punc_for_detection(
        *,
        train_path='train_shuffled_random_all_and_verbs_sampled_transformer.tsv',
        dev_path='dev_shuffled_random_all_and_verbs_sampled_transformer.tsv',
        model_name='GroNLP/bert-base-dutch-cased',
        output_dir='results_bertje_detection_no_punc',
        save_dir='./finetuned_bertje_sequence_classification_no_punc',
        do_lower_case=True):
    """
    Fine-tune BERTje as a binary classifier for word order errors in unpunctuated sentences.

    Train and dev data are tab-separated files with a header row. Their 'no_punc' column
    holds original sentences (label 1) and their 'scrambled_no_punc' column holds
    scrambled versions (label 0).

    Training:
    - runs for 3 epochs with batch size 128, 500 warmup steps and weight decay 0.01;
    - evaluates and checkpoints into `output_dir` after each epoch;
    - restores the best checkpoint at the end (`load_best_model_at_end=True`; with no
      `metric_for_best_model` set, Trainer ranks checkpoints by validation loss).

    The final model and its tokenizer are then written to `save_dir`.

    :param train_path: path to the training TSV file
    :param dev_path: path to the validation TSV file
    :param model_name: Hugging Face checkpoint to start from
    :param output_dir: directory for per-epoch checkpoints and trainer state
    :param save_dir: directory the fine-tuned model and tokenizer are saved to
    :param do_lower_case: passed to BertTokenizer (see NOTE below)
    :return: None
    """
    df_train = pd.read_csv(train_path, sep='\t', header=0, encoding='utf-8')
    df_dev = pd.read_csv(dev_path, sep='\t', header=0, encoding='utf-8')

    # NOTE(review): the checkpoint is *cased*, yet lowercasing is enabled by default here.
    # BertTokenizer also strips accents whenever it lowercases (unless strip_accents is set).
    # The default is kept so existing results and downstream scripts stay consistent;
    # consider do_lower_case=False for new runs.
    tokenizer = BertTokenizer.from_pretrained(model_name, do_lower_case=do_lower_case)
    # num_labels is left at its default (2), matching the binary labels built below.
    model = BertForSequenceClassification.from_pretrained(model_name)

    train_texts, train_labels = _texts_and_labels(df_train)
    val_texts, val_labels = _texts_and_labels(df_dev)

    train_dataset = SequenceClassificationDataset(train_texts, train_labels, tokenizer)
    val_dataset = SequenceClassificationDataset(val_texts, val_labels, tokenizer)

    training_args = TrainingArguments(output_dir=output_dir,
                                      num_train_epochs=3,
                                      per_device_train_batch_size=128,
                                      per_device_eval_batch_size=128,
                                      warmup_steps=500,
                                      weight_decay=0.01,
                                      save_strategy='epoch',
                                      evaluation_strategy='epoch',
                                      load_best_model_at_end=True,
                                      report_to=[])

    trainer = Trainer(model=model,
                      args=training_args,
                      train_dataset=train_dataset,
                      eval_dataset=val_dataset)

    trainer.train()

    trainer.save_model(save_dir)
    # Trainer was not given the tokenizer, so save_model() writes only the model. Saving
    # the tokenizer alongside makes save_dir reloadable with the same (lowercasing) config.
    tokenizer.save_pretrained(save_dir)



if __name__ == '__main__':
    # Entry point: __name__ is '__main__' both when run as a script and inside a notebook
    # kernel, so fine-tuning starts there; importing this module does not trigger training.
    fine_tune_bertje_no_punc_for_detection()


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at GroNLP/bert-base-dutch-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initial

Epoch,Training Loss,Validation Loss
1,0.0198,0.018167
2,0.0094,0.017558


IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve