# Fine-tune token classifier for social group mention detection and extraction

*author*: **Hauke Licht** (hauke.licht@uibk.ac.at)

<br>

In this notebook, we fine-tune a transformer encoder for social group mention detection through supervised token classification, using annotations from

> Licht H, Sczepanski R. Detecting Group Mentions in Political Rhetoric: A Supervised Learning Approach. *forthcoming*. *The British Journal of Political Science*. doi:[10.31219/osf.io/ufb96](https://doi.org/10.31219/osf.io/ufb96)

In particular, we will use the labels obtained from two coders' group mention annotations of UK party manifestos: https://github.com/haukelicht/group_mention_detection/blob/main/replication/data/annotation/labeled/uk-manifestos_all_labeled.jsonl (but see also the other two data files in the repository).

<!-- <a target="_blank" href="https://colab.research.google.com/github/haukelicht/comptext25_task_type_toolkit_tutorial/blob/main/span_extraction/token_classifier_finetuning.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>  -->

## 1. Setup

### Setup Colab (if needed)

In [None]:
# check if on Colab
ON_COLAB = True
try:
  from google import colab
except ImportError:
  ON_COLAB = False

In [None]:
# install required packages
if ON_COLAB:
    !pip install -q nltk==3.9.1 accelerate~=1.5.0 datasets==3.5.0 tokenizers==0.21.1 transformers~=4.51.3 scikit-learn==1.6.1 seqeval==1.2.2
    !pip install -q --upgrade --force-reinstall --no-deps git+https://github.com/haukelicht/soft-seqeval.git@main

In [None]:
# download the data
if ON_COLAB:
    !mkdir -p data
    GITHUB_PATH="https://raw.githubusercontent.com/haukelicht/group_mention_detection/refs/heads/main/replication/data/annotation/labeled/"
    !wget -O data/uk-manifestos_all_labeled.jsonl -q $GITHUB_PATH/uk-manifestos_all_labeled.jsonl

### Define the arguments

*Note:* 
I'm using `types.SimpleNamespace` here so that the object behaves similarly to the output of `argparse.ArgumentParser().parse_args()`.
This way, it's very easy to convert this notebook to an executable Python script.
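
As a sketch of that conversion, the same settings could be declared with `argparse`. The function name `build_parser`, the flag names, and the subset of arguments shown are assumptions chosen to mirror the `args` attributes defined in this notebook.

```python
import argparse
from pathlib import Path

def build_parser() -> argparse.ArgumentParser:
    # hypothetical CLI; flag names mirror the attributes set on `args` in this notebook
    parser = argparse.ArgumentParser(description="Fine-tune a token classifier")
    parser.add_argument("--data_file", type=Path, required=True)
    parser.add_argument("--model_name", type=str, default="answerdotai/ModernBERT-base")
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--learning_rate", type=float, default=4e-5)
    parser.add_argument("--seed", type=int, default=42)
    return parser

# in a script, call `build_parser().parse_args()` without a list to read sys.argv
cli_args = build_parser().parse_args(["--data_file", "data/uk-manifestos_all_labeled.jsonl"])
print(cli_args.model_name, cli_args.epochs)
```

Because the parsed result exposes its values as attributes, code written against `args.<name>` should work unchanged with either object.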

In [53]:
from pathlib import Path
from types import SimpleNamespace

args = SimpleNamespace()

args.data_file = Path('data/uk-manifestos_all_labeled.jsonl') if ON_COLAB else Path('../replication/data/annotation/labeled/uk-manifestos_all_labeled.jsonl')
args.dev_size = 0.1
args.test_size = 0.2

# model name in huggingface model hub 
args.model_name = "answerdotai/ModernBERT-base"  # "answerdotai/ModernBERT-large"

# path where to save temporary files and the final model (if desired)
args.output_path = Path('results/finetuning')

## hyperparameters
args.epochs=10
args.learning_rate=4e-5
args.train_batch_size=16
args.gradient_accumulation_steps=2
args.eval_batch_size=32
args.weight_decay=0.3

## early stopping
args.early_stopping = True
args.metric = 'seqeval-social group_f1' # metric used for early stopping and to select the best model among saved checkpoints after stopping
args.early_stopping_patience = 3
args.early_stopping_threshold = 0.03

# for reproducibility
args.seed = 42

args.save_finetuned_model = False

### Load required libraries

In [6]:
import shutil
import numpy as np
import pandas as pd
import json

from datasets import Dataset, DatasetDict

import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    set_seed,
)
set_seed(args.seed)
# uncomment the next two lines if you want to suppress the logging output
# from transformers.utils import logging
# logging.set_verbosity_error()

from soft_seqeval.metrics import compute_sequence_metrics

### Define custom helper functions

In [7]:
import json
from typing import Dict, Any, List, Union
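def read_jsonl(path: Union[Path, str], replace_newlines: bool = False) -> List[Dict[str, Any]]:
    """
    Read a JSON Lines file from `path` (one JSON object per line).

    If `replace_newlines` is True, escaped newlines ("\\n") inside the JSON
    strings are replaced with spaces before parsing.
    """
    with open(path, encoding='utf-8') as infile:
        # skip blank lines (a line read from a file still ends in a newline, so test the stripped line)
        lines = [line for line in infile if line.strip()]
    if replace_newlines:
        lines = [line.replace("\\n", " ") for line in lines]
    return [json.loads(line) for line in lines]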
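
If compressed data files need to be read as well, one option is to pick the opener based on the file extension. The following is a minimal sketch assuming gzip-compressed JSON Lines files with a `.gz` suffix; the function name `read_jsonl_maybe_gz` is hypothetical and not part of the replication code.

```python
import gzip
import json
from pathlib import Path
from typing import Any, Dict, List, Union

def read_jsonl_maybe_gz(path: Union[Path, str]) -> List[Dict[str, Any]]:
    """Hypothetical variant of `read_jsonl` that also accepts gzip-compressed files."""
    path = Path(path)
    # gzip.open in text mode ("rt") yields decoded lines, just like open()
    opener = gzip.open if path.suffix == ".gz" else open
    with opener(path, "rt", encoding="utf-8") as infile:
        # skip blank lines, parse every other line as one JSON object
        return [json.loads(line) for line in infile if line.strip()]
```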

In [8]:
# see also: https://github.com/haukelicht/group_mention_detection/blob/main/replication/code/utils/classification.py
def tokenize_and_align_sequence_labels(examples, tokenizer, **kwargs) -> Dict:
    # source: simplified from https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/examples/pytorch/token-classification/run_ner.py#L442
    tokenized_inputs = tokenizer(examples['tokens'], is_split_into_words=True, **kwargs)

    labels = []
    for i, label in enumerate(examples['labels']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # Map tokens to their respective word.
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:  # Set the special tokens to -100.
            if word_idx is None:
                label_ids.append(-100)
            elif word_idx != previous_word_idx:  # Only label the first token of a given word.
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs['labels'] = labels
    return tokenized_inputs
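
To see what the alignment does without downloading a tokenizer, the inner loop can be run on a hand-written list of word IDs. This is a minimal sketch: the helper name `align_labels` and the example `word_ids` are made up for illustration, and a real tokenizer would produce the word IDs via `tokenized_inputs.word_ids(...)`.

```python
from typing import List, Optional

def align_labels(word_ids: List[Optional[int]], word_labels: List[int]) -> List[int]:
    """Label only the first subword token of each word; mask everything else with -100."""
    aligned, previous = [], None
    for word_idx in word_ids:
        if word_idx is None:            # special token such as [CLS] or [SEP]
            aligned.append(-100)
        elif word_idx != previous:      # first subword token of a word
            aligned.append(word_labels[word_idx])
        else:                           # continuation subword token
            aligned.append(-100)
        previous = word_idx
    return aligned

# hypothetical tokenization: the word at index 1 is split into two subword tokens
word_ids = [None, 0, 1, 1, 2, None]
word_labels = [0, 6, 1]
print(align_labels(word_ids, word_labels))  # [-100, 0, 6, -100, 1, -100]
```

Positions labeled `-100` are ignored by the cross-entropy loss in PyTorch, so only one prediction per word contributes to training.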

In [9]:
from datasets import Dataset
# see also: https://github.com/haukelicht/group_mention_detection/blob/main/replication/code/utils/classification.py
def create_token_classification_dataset(
    data: List[Dict], 
    tokens_field: str='tokens',
    labels_field: Union[None, str]='labels'
):
    dataset = Dataset.from_list(data)
    if tokens_field != 'tokens':
        dataset = dataset.rename_column(tokens_field, 'tokens')
    if labels_field is not None and labels_field != 'labels':
        dataset = dataset.rename_column(labels_field, 'labels')
    required = ['tokens'] if labels_field is None else ['tokens', 'labels']
    rm = [c for c in dataset.column_names if c not in required]
    if len(rm) > 0:
        dataset = dataset.remove_columns(rm)
    return dataset

## 2. Load and prepare the data


Our data files at https://github.com/haukelicht/group_mention_detection/blob/main/replication/data/annotation/labeled/ contain the token-level labels aggregated from our two coders' group mention annotations.

Each file is a [JSONlines file](https://jsonlines.org/), that is, a text file with one JSON dictionary per line.

An exemplary line looks like this:

```json
{
 "id": "829ac29cd9304a66265e3ea830a505e3",
 "text": "Seit 150 Jahren machen wir Politik für eine bessere Gesellschaft .",
 "tokens": [
  "Seit",
  "150",
  "Jahren",
  "machen",
  "wir",
  "Politik",
  "für",
  "eine",
  "bessere",
  "Gesellschaft",
  "."],
 "annotations": {"emarie": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
 "labels": {"BSCModel": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
 "metadata": {"sentence_id": "41320.000.2013.1.1-2520-1",
  "split_": "smarie",
  "job": "group-mentions-annotation-de-manifestos-round-01"}
}
```

Note that each dictionary records 

- the text in pre-tokenized format, and
- annotations and labels at the _token_ level (dictionaries of lists, with one list per annotator or aggregation model).

In particular, 

- Annotations are in the field `"annotations"` and map annotator IDs to their token-level annotations.
- **Labels** are in the field `"labels"`, and `"BSCModel"` records the Bayesian Sequence Combination (BSC) model-based aggregate labels.
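
As a small illustration of this structure, the following sketch parses one line and pulls out the tokens and the aggregate labels. The line is a shortened, made-up record in the same format (the annotator ID `coder_1` is hypothetical), not an actual entry from the data files.

```python
import json

# a shortened, made-up line in the same format as the data files
line = (
    '{"id": "abc", "tokens": ["Support", "for", "families", "."], '
    '"annotations": {"coder_1": [0, 0, 6, 0]}, '
    '"labels": {"BSCModel": [0, 0, 6, 0]}}'
)

record = json.loads(line)
tokens = record["tokens"]
labels = record["labels"]["BSCModel"]  # BSC model-based aggregate labels

# tokens and labels are parallel lists: one label ID per token
paired = list(zip(tokens, labels))
print(paired)  # [('Support', 0), ('for', 0), ('families', 6), ('.', 0)]
```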


In [72]:
def parse_record(d):
    return {'id': d['id'], 'tokens': d['tokens'], 'labels': d['labels']['BSCModel']}

data = read_jsonl(args.data_file)

data = [parse_record(d) for d in data]

In [73]:
doc = data[5]
for t, l in zip(doc['tokens'], doc['labels']):
    print(f"{t} ==> {l}")

The ==> 8
MoD ==> 3
should ==> 0
provide ==> 0
much ==> 0
better ==> 0
support ==> 0
to ==> 0
next ==> 6
of ==> 1
kin ==> 1
and ==> 1
bereaved ==> 1
families ==> 1
in ==> 1
the ==> 1
event ==> 1
of ==> 1
a ==> 1
loss ==> 1
of ==> 1
a ==> 1
serving ==> 1
relative ==> 1
. ==> 0


In [74]:
# show available label IDs
set(l for d in data for l in d['labels'])
# NOTE: implies 5 types (0 reserved for outside token and otherwise 2 IDs (B and I) per group type)

{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

The records in `"labels"` are numeric label IDs that map onto our **group types**:

- social group
- political group
- political institution
- organization, public institution, or collective actor
- implicit social group reference

Here is how to convert them to text labels:

In [None]:
# get list of entity types
types = [
  "social group",
  "political group",
  "political institution",
  "organization, public institution, or collective actor",
  "implicit social group reference",
]
# convert to IOB2 scheme
scheme = ['O'] + ['I-'+t for t in types] + ['B-'+t for t in types]
# map label type indicators to label IDs
label2id = {l: i for i, l in enumerate(scheme)}
# and vice versa
id2label = {i: l for i, l in enumerate(scheme)}
NUM_LABELS = len(label2id)

label2id
# NOTE: the span-level annotations are encoded as token-level labels using the IOB2 scheme.
#       This means that
#        - a word that is not part of any entity is labeled "O",
#        - a word at the beginning of a span is labeled "B-<entity_type>", and
#        - a word inside a span is labeled "I-<entity_type>"

{'O': 0,
 'I-social group': 1,
 'I-political group': 2,
 'I-political institution': 3,
 'I-organization, public institution, or collective actor': 4,
 'I-implicit social group reference': 5,
 'B-social group': 6,
 'B-political group': 7,
 'B-political institution': 8,
 'B-organization, public institution, or collective actor': 9,
 'B-implicit social group reference': 10}
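
To go back from token-level label IDs to spans, the IOB2 tags can be decoded in one pass. The sketch below is one possible implementation, not code from the replication materials: the function name `decode_spans` is hypothetical, and it assumes well-formed IOB2 input (an `I-` tag without a matching open span is dropped). The label IDs are those of the example sentence printed above ("The MoD should provide ...").

```python
from typing import List, Tuple

# same definitions as in the cell above, repeated to keep the example self-contained
types = [
    "social group",
    "political group",
    "political institution",
    "organization, public institution, or collective actor",
    "implicit social group reference",
]
scheme = ["O"] + ["I-" + t for t in types] + ["B-" + t for t in types]
id2label = dict(enumerate(scheme))

def decode_spans(label_ids: List[int]) -> List[Tuple[int, int, str]]:
    """Return (start, end, type) word spans; `end` is exclusive."""
    spans, start, current = [], None, None
    for i, label_id in enumerate(label_ids):
        prefix, _, typ = id2label[label_id].partition("-")
        continues = prefix == "I" and typ == current
        if not continues:
            if current is not None:
                spans.append((start, i, current))  # close the open span
            # a B- tag opens a new span; anything else leaves us outside a span
            start, current = (i, typ) if prefix == "B" else (None, None)
    if current is not None:
        spans.append((start, len(label_ids), current))
    return spans

label_ids = [8, 3] + [0] * 6 + [6] + [1] * 15 + [0]
print(decode_spans(label_ids))
# [(0, 2, 'political institution'), (8, 24, 'social group')]
```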

### Split the data

In [28]:
from typing import Optional, Union
from sklearn.model_selection import train_test_split
def split_data(
        data: List[Dict],
        test_size: Union[None, float, int]=0.2,
        dev_size: Union[None, float, int]=0.2,
        stratify_by: Optional[Union[str, List[str]]]=None,
        seed: int=42,
        return_dict: bool=False
    ):
    """Split a corpus into training, development, and test sets.

    Args:

    data: List[Dict]
        The corpus to split. Must be a list of dictionaries.
    dev_size: float
        The proportion of the data remaining after the test split to include in the development set.
    test_size: float
        The proportion of the data to include in the test set.
    stratify_by: str or list of str, optional
        Metadata field(s) to use for stratified splitting. If a single field is 
        provided, the data will be stratified by the values in that field in the metadata. 
        If multiple columns are provided, the data will be stratified by 
        the unique combinations of values of these fields in the metadata.
    seed: int
        Random seed for reproducibility.
    return_dict: bool
        Whether to return the splits as a dictionary.
    """
    n = len(data)
    
    if stratify_by:
        assert all('metadata' in doc for doc in data), "Stratification requires 'metadata' field in each document's dictionary"
        if isinstance(stratify_by, str):
            stratify_by = [stratify_by]
        for field in stratify_by:
            assert all(field in doc['metadata'] for doc in data), f"Field '{field}' not found in 'metadata' of all documents"
        # create a grouping indicator based on the stratification columns
        strata = ['__'.join([str(doc['metadata'][field]) for field in stratify_by]) for doc in data]
    else:
        strata = None
        
    idxs = list(range(n))
    tmp, test_idxs = train_test_split(idxs, test_size=test_size, random_state=seed, stratify=strata)# if test_size > 0 else (idxs, [])
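    # the second split operates on `tmp`, so the strata must be subset to these indices (not to the test indices)
    strata = [strata[i] for i in tmp] if stratify_by else None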
    train_idxs, dev_idxs = train_test_split(tmp, test_size=dev_size, random_state=seed, stratify=strata)# if dev_size > 0 else (idxs, [])

    train, dev, test = [data[i] for i in train_idxs], [data[i] for i in dev_idxs], [data[i] for i in test_idxs]
    
    if return_dict:
        return {'train': train, 'dev': dev, 'test': test}
    else:
        return train, dev, test

In [31]:
data_splits = split_data(data, test_size=args.test_size, dev_size=args.dev_size, seed=args.seed, return_dict=True)
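
Note that `split_data` applies `dev_size` to what remains after the test split, not to the full corpus. The arithmetic can be sketched as follows, assuming fractional sizes are rounded up the way scikit-learn's `train_test_split` does it; the helper `split_sizes` is hypothetical, and the corpus size of 8,576 is the sum of the split sizes reported further down in this notebook.

```python
import math

def split_sizes(n: int, test_size: float, dev_size: float):
    """Return (n_train, n_dev, n_test) for two consecutive random splits."""
    # assumption: fractional sizes are rounded up, as in scikit-learn's train_test_split
    n_test = math.ceil(test_size * n)
    remainder = n - n_test
    # the dev split is taken from the remainder, not from the full corpus
    n_dev = math.ceil(dev_size * remainder)
    return remainder - n_dev, n_dev, n_test

print(split_sizes(8576, test_size=0.2, dev_size=0.1))  # (6174, 686, 1716)
```

With these settings the development set therefore holds about 8% of the full corpus (0.1 × 0.8), not 10%.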

In [None]:
# load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True, add_prefix_space=True)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

# apply the custom function defined above to set subword tokens' labels to -100
# this is necessary because the tokenization may split a word into multiple subwords
datasets = DatasetDict({split: create_token_classification_dataset(data) for split, data in data_splits.items()})
datasets = datasets.map(lambda example: tokenize_and_align_sequence_labels(example, tokenizer=tokenizer), batched=True)

Map: 100%|██████████| 6174/6174 [00:00<00:00, 12340.91 examples/s]
Map: 100%|██████████| 686/686 [00:00<00:00, 13531.16 examples/s]
Map: 100%|██████████| 1716/1716 [00:00<00:00, 14177.01 examples/s]


In [36]:
datasets.num_rows

{'train': 6174, 'dev': 686, 'test': 1716}

In [37]:
# show an example
example = datasets['train'][2]
for t, l in zip(example['input_ids'], example['labels']):
    if t == tokenizer.pad_token_id:
        break
    print(l, '\t', repr(tokenizer.decode(t)))

-100 	 '[CLS]'
0 	 ' We'
0 	 ' shall'
0 	 ' end'
0 	 ' the'
0 	 ' practice'
0 	 ' of'
0 	 ' allowing'
0 	 ' permanent'
0 	 ' settlement'
0 	 ' for'
6 	 ' those'
1 	 ' who'
1 	 ' come'
1 	 ' here'
1 	 ' for'
1 	 ' a'
1 	 ' temporary'
1 	 ' stay'
0 	 '.'
-100 	 '[SEP]'


In [None]:
# NOTE: after tokenization, the texts are represented by their token IDs,
#       so we can remove the raw tokens from the dataset (no need to load them onto the GPU)
datasets = datasets.remove_columns(['tokens']) 

## Prepare fine-tuning

In [38]:
# NOTE: the `model_init` function is used by the Trainer to initialize the model
#   and is called each time before training starts.
#  So we define it here to load the model from the Huggingface model hub
#   and set the number of labels to the number of unique labels in the dataset
#   and the label2id and id2label mappings
def model_init():
    config = AutoConfig.from_pretrained(args.model_name)
    config.num_labels = NUM_LABELS
    config.label2id = label2id
    config.id2label = id2label
    return AutoModelForTokenClassification.from_pretrained(args.model_name, config=config, device_map='auto')

In [39]:
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)
    # convert predictions and labels to list of lists of ints
    predictions = predictions.astype(int).tolist()
    labels = labels.astype(int).tolist()
    return compute_sequence_metrics(y_true=labels, y_pred=predictions, id2label=id2label, flatten_output=True)
# NOTE: the `compute_metrics` function is used by the Trainer to compute the evaluation metrics 

In [40]:
# NOTE: at the beginning of the script, we have defined args.metric as the metric to be used for early stopping
#       and model selection among saved checkpoints after stopping
#       This metric must be available in the output of our `compute_metrics` function defined above
#       So let's check this

ex = ['O', 'B-social group', 'I-social group', 'O']
scores = compute_sequence_metrics([ex], [ex], id2label, flatten_output=True)
if args.metric not in scores.keys():
    raise ValueError(f"Invalid metric: {args.metric}, valid metrics are: {', '.join(scores.keys())}")

In [52]:
# look at metrics computed by the `compute_sequence_metrics` function provided by my soft-seqeval package
# NOTE: the pattern is <scheme>-<entity_type>_<metric> where
#       - <scheme> is one of
#           - "seqeval": strict seqeval metric as implemented in https://github.com/chakki-works/seqeval
#           - "softseqeval": soft-seqeval metrics implemented as described here https://github.com/haukelicht/soft-seqeval/blob/main/notebooks/available_metrics.ipynb
#           - "wordlevel": evaluation at word level
#           - "doclevel": evaluation at document (i.e., sentence) level
#       - <entity_type> is the entity type (e.g., "social group"), or "macro" or "micro" for the macro or micro average over all entity types
#       - <metric> is the metric used for evaluation (f1, precision, or recall)
metrics_overview = pd.DataFrame(list(scores.keys()), columns=['key'])
metrics_overview[['scheme', 'metric']] = metrics_overview.key.str.split('-', n=2, expand=True)
metrics_overview[['type', 'metric']] = metrics_overview.metric.str.split('_', n=2, expand=True)
metrics_overview

| | key | scheme | metric | type |
|---|---|---|---|---|
| 0 | seqeval-macro_f1 | seqeval | f1 | macro |
| 1 | seqeval-macro_precision | seqeval | precision | macro |
| 2 | seqeval-macro_recall | seqeval | recall | macro |
| 3 | seqeval-micro_f1 | seqeval | f1 | micro |
| 4 | seqeval-micro_precision | seqeval | precision | micro |
| 5 | seqeval-micro_recall | seqeval | recall | micro |
| 6 | seqeval-social group_f1 | seqeval | f1 | social group |
| 7 | seqeval-social group_precision | seqeval | precision | social group |
| 8 | seqeval-social group_recall | seqeval | recall | social group |
| 9 | softseqeval-macro_f1 | softseqeval | f1 | macro |


**Note:** Because we are most interested in detecting exact occurrences of social group mentions, we will use `seqeval-social group_f1` for early stopping and model selection.
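
To illustrate what "exact" means here, the following sketch scores predicted spans by exact match on start, end, and type. It is a simplified stand-in for strict span-level evaluation, not the seqeval implementation; the function name `strict_span_scores` and the predicted spans are made up for illustration.

```python
def strict_span_scores(true_spans: set, pred_spans: set) -> dict:
    """Exact-match precision, recall and F1 over (start, end, type) spans."""
    tp = len(true_spans & pred_spans)  # a span counts only if start, end and type all match
    precision = tp / len(pred_spans) if pred_spans else 0.0
    recall = tp / len(true_spans) if true_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

true_spans = {(8, 24, "social group"), (0, 2, "political institution")}
# hypothetical prediction: the social group span starts correctly but ends too early
pred_spans = {(8, 14, "social group"), (0, 2, "political institution")}
print(strict_span_scores(true_spans, pred_spans))
# {'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
```

Under this strict criterion, the partially overlapping social group span earns no credit, which is why strict scores tend to be lower than word-level or overlap-based scores for the same predictions.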


### Define the training arguments

In [54]:
out_dir = args.output_path
checkpoints_dir = out_dir / 'checkpoints'
logs_dir = out_dir / 'logs'

training_args = TrainingArguments(
    
    # hyperparameters
    num_train_epochs=args.epochs,
    learning_rate=args.learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    per_device_eval_batch_size=args.eval_batch_size,
    weight_decay=args.weight_decay,
    optim='adamw_torch',
    
    # when to evaluate
    eval_strategy='epoch',
    # how to select "best" model
    do_eval=bool('dev' in datasets),
    metric_for_best_model=args.metric,
    load_best_model_at_end=True,
    # when to save
    save_strategy='epoch',
    save_total_limit=2 if 'dev' in datasets else None, # don't save all model checkpoints
    # where to store results
    output_dir=checkpoints_dir,
    overwrite_output_dir=True,
    
    # logging
    logging_dir=logs_dir,
    logging_strategy='epoch',
    report_to='none',
    
    # reproducibility
    seed=args.seed,
    data_seed=args.seed,
    full_determinism=True
)


# build callbacks
callbacks = []
if args.early_stopping:
    if 'dev' not in datasets:
        raise ValueError('Early stopping requires a dev data set')
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience, early_stopping_threshold=args.early_stopping_threshold))
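
The interplay of patience and threshold can be sketched in plain Python. This is a simplified illustration of the general idea (stop once the metric has failed to improve by more than the threshold for `patience` consecutive evaluations), not the `EarlyStoppingCallback` implementation, whose exact bookkeeping may differ; the function name `stopping_epoch` and the scores are made up.

```python
def stopping_epoch(scores, patience: int, threshold: float):
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best, bad_epochs = None, 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score - best > threshold:
            bad_epochs = 0          # a large enough improvement resets the counter
        else:
            bad_epochs += 1         # too small an improvement (or a decline) counts against patience
        if best is None or score > best:
            best = score            # the best score is tracked independently of the threshold
        if bad_epochs >= patience:
            return epoch
    return None

# hypothetical dev-set F1 scores, one per epoch
print(stopping_epoch([0.40, 0.50, 0.51, 0.52, 0.53, 0.60], patience=3, threshold=0.03))  # 5
```

In this example the score keeps rising, but three consecutive gains of 0.01 fall short of the 0.03 threshold, so training would stop after epoch 5 and never reach the jump at epoch 6. A high threshold can thus end training earlier than a look at the raw scores might suggest.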


## Fine-tuning

### Create the trainer

In [55]:
trainer = Trainer(
    model_init=model_init,
    args=training_args,
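    # NOTE: only the first 1000 training examples are used here; drop `.select(...)` to train on the full training split
    train_dataset=datasets['train'].select(range(0, 1000)),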
    eval_dataset=datasets['dev'],
    processing_class=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
    callbacks=callbacks
)

Some weights of ModernBertForTokenClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Fine-tune

In [56]:
print('Training ...')
train_hist = trainer.train()

Training ...


Some weights of ModernBertForTokenClassification were not initialized from the model checkpoint at answerdotai/ModernBERT-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Seqeval-macro F1,Seqeval-macro Precision,Seqeval-macro Recall,Seqeval-micro F1,Seqeval-micro Precision,Seqeval-micro Recall,Seqeval-political institution F1,Seqeval-political institution Precision,Seqeval-political institution Recall,"Seqeval-organization, public institution, or collective actor F1","Seqeval-organization, public institution, or collective actor Precision","Seqeval-organization, public institution, or collective actor Recall",Seqeval-political group F1,Seqeval-political group Precision,Seqeval-political group Recall,Seqeval-implicit social group reference F1,Seqeval-implicit social group reference Precision,Seqeval-implicit social group reference Recall,Seqeval-social group F1,Seqeval-social group Precision,Seqeval-social group Recall,Softseqeval-macro F1,Softseqeval-macro Precision,Softseqeval-macro Recall,Softseqeval-micro F1,Softseqeval-micro Precision,Softseqeval-micro Recall,Softseqeval-political institution F1,Softseqeval-political institution Precision,Softseqeval-political institution Recall,"Softseqeval-organization, public institution, or collective actor F1","Softseqeval-organization, public institution, or collective actor Precision","Softseqeval-organization, public institution, or collective actor Recall",Softseqeval-political group F1,Softseqeval-political group Precision,Softseqeval-political group Recall,Softseqeval-implicit social group reference F1,Softseqeval-implicit social group reference Precision,Softseqeval-implicit social group reference Recall,Softseqeval-social group F1,Softseqeval-social group Precision,Softseqeval-social group Recall,Doclevel-micro Precision,Doclevel-micro Recall,Doclevel-micro F1,Doclevel-political institution Precision,Doclevel-political institution Recall,Doclevel-political institution F1,"Doclevel-organization, public institution, or collective actor Precision","Doclevel-organization, public institution, or collective actor Recall","Doclevel-organization, public institution, or collective actor F1",Doclevel-political group Precision,Doclevel-political group Recall,Doclevel-political group F1,Doclevel-implicit social group reference Precision,Doclevel-implicit social group reference Recall,Doclevel-implicit social group reference F1,Doclevel-social group Precision,Doclevel-social group Recall,Doclevel-social group F1,Wordlevel-accuracy,Wordlevel-macro F1,Wordlevel-macro Precision,Wordlevel-macro Recall,Wordlevel-o F1,Wordlevel-o Precision,Wordlevel-o Recall,Wordlevel-political institution F1,Wordlevel-political institution Precision,Wordlevel-political institution Recall,"Wordlevel-organization, public institution, or collective actor F1","Wordlevel-organization, public institution, or collective actor Precision","Wordlevel-organization, public institution, or collective actor Recall",Wordlevel-political group F1,Wordlevel-political group Precision,Wordlevel-political group Recall,Wordlevel-implicit social group reference F1,Wordlevel-implicit social group reference Precision,Wordlevel-implicit social group reference Recall,Wordlevel-social group F1,Wordlevel-social group Precision,Wordlevel-social group Recall
1,0.6736,0.370977,0.202423,0.193425,0.224246,0.221287,0.206974,0.237726,0.161369,0.141631,0.1875,0.117647,0.156863,0.094118,0.492754,0.481132,0.50495,0.0,0.0,0.0,0.240343,0.1875,0.334661,0.278971,0.328959,0.262628,0.529547,0.583547,0.521703,0.245629,0.309075,0.21851,0.167319,0.235577,0.142353,0.611528,0.697222,0.573889,0.0,0.0,0.0,0.370379,0.402922,0.37839,0.817784,0.817784,0.817784,0.801749,0.801749,0.801749,0.832362,0.832362,0.832362,0.959184,0.959184,0.959184,0.893586,0.893586,0.893586,0.832362,0.832362,0.832362,0.906742,0.488941,0.552843,0.465322,0.958876,0.947125,0.970923,0.428769,0.487421,0.382716,0.297456,0.580153,0.2,0.698305,0.811024,0.613095,0.0,0.0,0.0,0.550243,0.491337,0.625197
2,0.3297,0.252452,0.345489,0.331207,0.364713,0.397135,0.400262,0.394057,0.234177,0.264286,0.210227,0.342391,0.318182,0.370588,0.679245,0.648649,0.712871,0.0,0.0,0.0,0.471631,0.42492,0.52988,0.424062,0.48143,0.400841,0.626779,0.70016,0.597015,0.377304,0.493367,0.33038,0.420441,0.479985,0.39543,0.758395,0.811782,0.732423,0.0,0.0,0.0,0.564171,0.622015,0.54597,0.855685,0.855685,0.855685,0.887755,0.887755,0.887755,0.889213,0.889213,0.889213,0.970845,0.970845,0.970845,0.893586,0.893586,0.893586,0.873178,0.873178,0.873178,0.933278,0.597682,0.671275,0.551316,0.96871,0.948784,0.989491,0.513912,0.762136,0.387654,0.58963,0.674576,0.523684,0.834891,0.875817,0.797619,0.0,0.0,0.0,0.678947,0.766337,0.609449
3,0.1866,0.217188,0.480093,0.491782,0.493756,0.48814,0.472222,0.505168,0.282051,0.323529,0.25,0.451282,0.4,0.517647,0.731481,0.686957,0.782178,0.37931,0.55,0.289474,0.556338,0.498423,0.629482,0.567324,0.619501,0.54638,0.692779,0.751088,0.672979,0.441944,0.546875,0.396048,0.54147,0.599182,0.516409,0.857237,0.884091,0.845152,0.339286,0.380952,0.319444,0.656684,0.686405,0.654845,0.88484,0.88484,0.88484,0.906706,0.906706,0.906706,0.900875,0.900875,0.900875,0.983965,0.983965,0.983965,0.930029,0.930029,0.930029,0.902332,0.902332,0.902332,0.945257,0.721726,0.840633,0.669425,0.973601,0.961121,0.98641,0.543624,0.848168,0.4,0.698499,0.725212,0.673684,0.903614,0.914634,0.892857,0.435897,0.809524,0.298246,0.77512,0.785137,0.765354
4,0.1045,0.214665,0.539938,0.544503,0.548882,0.541485,0.523522,0.560724,0.453039,0.44086,0.465909,0.459384,0.438503,0.482353,0.790476,0.761468,0.821782,0.409836,0.543478,0.328947,0.586957,0.538206,0.645418,0.616377,0.671525,0.59428,0.724714,0.789786,0.700617,0.586054,0.676375,0.550381,0.582319,0.647712,0.555859,0.834899,0.858108,0.824324,0.378431,0.423529,0.356863,0.700179,0.751902,0.683974,0.897959,0.897959,0.897959,0.930029,0.930029,0.930029,0.924198,0.924198,0.924198,0.98105,0.98105,0.98105,0.934402,0.934402,0.934402,0.905248,0.905248,0.905248,0.950063,0.754415,0.855275,0.693864,0.975496,0.962115,0.989254,0.684814,0.8157,0.590123,0.706745,0.798013,0.634211,0.900302,0.91411,0.886905,0.487805,0.8,0.350877,0.771331,0.841713,0.711811
5,0.0522,0.22396,0.560277,0.531309,0.598984,0.55648,0.515419,0.604651,0.467192,0.434146,0.505682,0.498728,0.439462,0.576471,0.796296,0.747826,0.851485,0.446043,0.492063,0.407895,0.593128,0.543046,0.653386,0.603171,0.638227,0.592774,0.737793,0.776968,0.729087,0.606712,0.669458,0.584977,0.558937,0.598981,0.544708,0.811638,0.819684,0.813937,0.365979,0.393471,0.353093,0.672591,0.709539,0.667157,0.905248,0.905248,0.905248,0.931487,0.931487,0.931487,0.906706,0.906706,0.906706,0.979592,0.979592,0.979592,0.925656,0.925656,0.925656,0.906706,0.906706,0.906706,0.951595,0.765325,0.803002,0.738969,0.977431,0.971141,0.983802,0.733154,0.807122,0.671605,0.721268,0.724138,0.718421,0.895954,0.870787,0.922619,0.491979,0.630137,0.403509,0.772162,0.814685,0.733858
6,0.0214,0.245494,0.565954,0.532491,0.607334,0.565321,0.523077,0.614987,0.439276,0.402844,0.482955,0.501333,0.458537,0.552941,0.801843,0.75,0.861386,0.455172,0.478261,0.434211,0.632143,0.572816,0.705179,0.612625,0.653888,0.597,0.74243,0.782047,0.730846,0.595874,0.657716,0.570848,0.586339,0.638117,0.565333,0.815004,0.835489,0.806034,0.388333,0.421667,0.3725,0.677575,0.716453,0.670283,0.902332,0.902332,0.902332,0.930029,0.930029,0.930029,0.925656,0.925656,0.925656,0.982507,0.982507,0.982507,0.928571,0.928571,0.928571,0.899417,0.899417,0.899417,0.952013,0.770888,0.809325,0.741692,0.977586,0.971371,0.983881,0.717131,0.775862,0.666667,0.720994,0.758721,0.686842,0.891496,0.878613,0.904762,0.534031,0.662338,0.447368,0.784091,0.809045,0.76063
7,0.0099,0.272648,0.586422,0.56894,0.613215,0.583384,0.549714,0.621447,0.486772,0.455446,0.522727,0.515306,0.454955,0.594118,0.851675,0.824074,0.881188,0.447761,0.517241,0.394737,0.630597,0.592982,0.673307,0.630416,0.663421,0.619705,0.752157,0.79063,0.742455,0.617539,0.670635,0.599486,0.575159,0.611799,0.562632,0.872936,0.876911,0.875382,0.405185,0.438889,0.388889,0.681258,0.71887,0.672134,0.913994,0.913994,0.913994,0.93586,0.93586,0.93586,0.91691,0.91691,0.91691,0.983965,0.983965,0.983965,0.937318,0.937318,0.937318,0.912536,0.912536,0.912536,0.953336,0.778307,0.829132,0.744785,0.978168,0.970522,0.985936,0.736424,0.794286,0.68642,0.73107,0.725389,0.736842,0.920821,0.907514,0.934524,0.533333,0.727273,0.421053,0.770026,0.84981,0.703937
8,0.0059,0.264297,0.600778,0.577515,0.631829,0.603636,0.568493,0.643411,0.506667,0.477387,0.539773,0.52987,0.474419,0.6,0.835681,0.794643,0.881188,0.463768,0.516129,0.421053,0.667904,0.625,0.717131,0.634413,0.669257,0.623573,0.752148,0.789876,0.742711,0.621689,0.680952,0.602899,0.5836,0.615471,0.573083,0.854167,0.861012,0.85506,0.419355,0.458781,0.401434,0.693256,0.730066,0.685387,0.905248,0.905248,0.905248,0.934402,0.934402,0.934402,0.915452,0.915452,0.915452,0.98688,0.98688,0.98688,0.937318,0.937318,0.937318,0.913994,0.913994,0.913994,0.954033,0.784354,0.827438,0.7552,0.978217,0.971847,0.984671,0.737683,0.800578,0.683951,0.742105,0.742105,0.742105,0.910145,0.887006,0.934524,0.554348,0.728571,0.447368,0.783626,0.83452,0.738583
9,0.0026,0.274412,0.593309,0.57134,0.62361,0.595281,0.559727,0.635659,0.506596,0.472906,0.545455,0.512821,0.454545,0.588235,0.834123,0.8,0.871287,0.452555,0.508197,0.407895,0.660448,0.621053,0.705179,0.633426,0.668346,0.622685,0.753789,0.790512,0.745213,0.624629,0.679245,0.606645,0.575706,0.611726,0.563264,0.857572,0.868769,0.856006,0.419355,0.456989,0.405018,0.689867,0.725,0.682492,0.912536,0.912536,0.912536,0.934402,0.934402,0.934402,0.912536,0.912536,0.912536,0.98688,0.98688,0.98688,0.938776,0.938776,0.938776,0.915452,0.915452,0.915452,0.954033,0.783781,0.826432,0.754587,0.978369,0.972224,0.984592,0.741425,0.796034,0.693827,0.740157,0.73822,0.742105,0.906433,0.890805,0.922619,0.554348,0.728571,0.447368,0.781955,0.83274,0.737008




## Evaluate

In [57]:
# apply the best model loaded after finishing training to the test set
print('Evaluating ...')
test_res = trainer.evaluate(datasets['test'], metric_key_prefix='test')

Evaluating ...




early stopping required metric_for_best_model, but did not find eval_seqeval-social group_f1 so early stopping is disabled


In [58]:
# create output that is nicer to look at
out = pd.DataFrame(test_res, index=['value']).T
out = out.reset_index().rename(columns={'index': 'cat'})
out[['set', 'scheme', 'metric', 'misc']] = out.cat.str.split('_', expand=True)
out = out[out.misc.isnull()]
out = out[out.metric.notnull()]
out[['scheme', 'type']] = out.scheme.str.split('-', expand=True)
out = out.drop(columns=['set', 'cat', 'misc'])
out = out[['scheme', 'type', 'metric', 'value']]
out = out.pivot(index=['type', 'scheme', ], columns='metric', values='value')
keys = [
    (typ, scheme)
    for typ in types
    for scheme in ['seqeval', 'softseqeval', 'wordlevel', 'doclevel']
]
out.loc[keys, :]

| type | scheme | f1 | precision | recall |
|---|---|---|---|---|
| social group | seqeval | 0.580122 | 0.52963 | 0.641256 |
| social group | softseqeval | 0.660909 | 0.705151 | 0.652858 |
| social group | wordlevel | 0.731133 | 0.801009 | 0.672471 |
| social group | doclevel | 0.907925 | 0.907925 | 0.907925 |
| political group | seqeval | 0.77187 | 0.728155 | 0.821168 |
| political group | softseqeval | 0.843238 | 0.857103 | 0.840604 |
| political group | wordlevel | 0.900648 | 0.891026 | 0.91048 |
| political group | doclevel | 0.985431 | 0.985431 | 0.985431 |
| political institution | seqeval | 0.486784 | 0.442886 | 0.540342 |
| political institution | softseqeval | 0.567123 | 0.615805 | 0.548582 |


## Inference

In [60]:
from datasets import Dataset

from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

from tqdm import tqdm

We'll use the `pipeline` API from `transformers` for inference (i.e., predicting spans in unlabeled data).

Specifically, we use the **NER** (named entity recognition) task and pass the fine-tuned model from the trainer.

In [61]:
extractor = pipeline(task='ner', model=trainer.model, tokenizer=tokenizer, batch_size=32, aggregation_strategy='simple')

Device set to use mps


In [67]:
docs = read_jsonl(args.data_file)

# take the first 500 sentences just for illustrative purposes
docs = docs[:500]

# keep only text and id fields
docs = [{f: doc[f] for f in ['text', 'id']} for doc in docs]

# for batch inference (see https://huggingface.co/docs/transformers/pipeline_tutorial#batch-inference)
kd = KeyDataset(Dataset.from_list(docs), 'text')

In [68]:
# apply the extractor to the dataset
pred_ents = [p for p in tqdm(extractor(kd), total=len(docs))]

100%|██████████| 500/500 [00:05<00:00, 95.94it/s] 


`pred_ents` contains one list per text in `docs`, and each of these lists holds one dictionary per predicted entity.

Each of these dictionaries has the following fields:

- start: character start index of the entity in the text
- end: character end index of the entity in the text
- score: confidence score of the prediction
- word: the text of the entity
- entity_group: the entity type (e.g., 'social group')
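
A quick way to work with these fields is to slice the text with `start` and `end` and filter on `score`. The sketch below uses a hand-written stand-in for the pipeline output of one text; the scores and the 0.5 cutoff are made up for illustration.

```python
text = "Neither Labour nor Conservatives are interested in changing our broken system ."

# hypothetical pipeline output for `text`; the scores are made up
pred = [
    {"entity_group": "political group", "score": 0.98, "word": " Labour", "start": 8, "end": 14},
    {"entity_group": "political group", "score": 0.41, "word": " Conservatives", "start": 19, "end": 32},
]

# `start` and `end` are character offsets into the text, with `end` exclusive
mentions = [(text[e["start"]:e["end"]], e["entity_group"], e["score"]) for e in pred]

# keep only predictions above an (arbitrary) confidence cutoff
confident = [(m, t) for m, t, s in mentions if s >= 0.5]
print(confident)  # [('Labour', 'political group')]
```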


Let's convert these predictions into one `Entities` instance per text and use it to create a `LabeledSequence` instance for each text:

In [69]:
from soft_seqeval.classes import Entity, Entities
from soft_seqeval.classes import LabeledSequence
from copy import deepcopy

def pipeline_output_to_entities(pred) -> Entities:
    """Take output from the NER pipeline and convert to Entities instance"""
    ents = []
    for ent in pred:
        ent = deepcopy(ent)
        if ent['word'][0] == ' ':
            ent['start'] += 1
        if ent['word'][-1] == ' ':
            ent['end'] -= 1
        ents.append(Entity(ent['start'], ent['end'], ent['entity_group']))
    return Entities(ents)

# iterate over the documents and predicted annotations to create a list of LabeledSequence instances
preds = [
    LabeledSequence(text=doc['text'], entities=pipeline_output_to_entities(pred), id=doc['id'], lang='english')
    for doc, pred in zip(docs, pred_ents)
]
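
The function above adjusts the offsets based on the decoded `word`. A possible alternative, sketched here under the assumption that the decoded `word` may carry a leading space that is not present in the text at `start`, is to trim based on the characters actually found in the text. The helper `trim_span` is hypothetical and not part of soft-seqeval.

```python
def trim_span(text: str, start: int, end: int):
    """Shrink [start, end) so that it neither begins nor ends with whitespace."""
    while start < end and text[start].isspace():
        start += 1
    while end > start and text[end - 1].isspace():
        end -= 1
    return start, end

text = "Prisons should be places of rehabilitation"
print(trim_span(text, 0, 7))   # (0, 7): nothing to trim
print(trim_span(text, 7, 15))  # (8, 14): surrounding spaces removed
```

Because this variant only moves an offset when the text itself has whitespace at that position, it cannot cut off the first character of a word at the start of a sentence.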

In [70]:
# look at first 10 examples
preds[:10]


[**89e12104790ce027289473f8814f710d**: "We must be not be bound by any freedom of movement obligation , and we must be free to set and meet our own annual migration targets .",
 **48b1c6ba33bb5e538c420148ec993090**: "This would provide a boost of over £100 million , which we believe will provide important new opportunities for **production companies** [organization, public institution, or collective actor] and **the creative sector** [organization, public institution, or collective actor] in Scotland .",
 **ef34613868b210a362e5bf0eee91fd35**: "Neither **Labour** [political group] nor **Conservatives** [political group] are interested in changing our broken system , because it works to keep them in power .",
 **7399dc908f28099df24ecc0092cff235**: "P**risons** [organization, public institution, or collective actor] should be places of rehabilitation : when **people** …

## Finally

#### Delete intermediate checkpoints and log files

In [71]:
# finally: clean up
if checkpoints_dir.exists():
    shutil.rmtree(checkpoints_dir)
if logs_dir.exists():
    shutil.rmtree(logs_dir)

#### Save the best model (if desired)

In [None]:
if args.save_finetuned_model:
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

    # # save results
    # import json
    # fp = os.path.join(out_dir, 'test_results.json')
    # with open(fp, 'w') as f:
    #     json.dump(test_res, f)

#### Free the GPU and remove large objects

In [75]:
import gc
# move the model to the CPU before dropping the references to it
trainer.model.to('cpu')
del trainer, tokenizer, datasets
gc.collect()

0