# Fine-tune token classifier for social group mention detection and extraction

*author*: **Hauke Licht** (hauke.licht@uibk.ac.at)

<br>

In this notebook, we fine-tune a transformer encoder for social group mention detection through supervised token classification using annotations from

> Licht H, Sczepanski R. Detecting Group Mentions in Political Rhetoric: A Supervised Learning Approach. *forthcoming*. *The British Journal of Political Science*. doi:[10.31219/osf.io/ufb96](https://doi.org/10.31219/osf.io/ufb96)

In particular, we will use the labels we have obtained from two coders' group mention annotations of UK party manifestos: https://github.com/haukelicht/group_mention_detection/blob/main/replication/data/annotation/labeled/uk-manifestos_all_labeled.jsonl (but see also the other two data files in the repository)

<!-- <a target="_blank" href="https://colab.research.google.com/github/haukelicht/comptext25_task_type_toolkit_tutorial/blob/main/span_extraction/token_classifier_finetuning.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>  -->

## 1. Setup

### Setup Colab (if needed)

In [None]:
# check if on Colab by probing for the `google.colab` module
try:
    from google import colab  # noqa: F401 (imported only to test availability)
except ImportError:
    # narrow handler: a bare `except` would also swallow KeyboardInterrupt/SystemExit
    ON_COLAB = False
else:
    ON_COLAB = True

In [None]:
# install required packages (only when running on Colab)
if ON_COLAB:
    # third-party dependencies with pinned / compatible-release versions
    !pip install -q nltk==3.9.1 accelerate~=1.5.0 datasets==3.5.0 tokenizers==0.21.1 transformers~=4.51.3 scikit-learn==1.6.1 seqeval==1.2.2
    # soft-seqeval is installed from GitHub without its dependencies
    # (presumably so that the versions installed above are left untouched)
    !pip install -q --upgrade --force-reinstall --no-deps git+https://github.com/haukelicht/soft-seqeval.git@main

In [None]:
# download the data (only when running on Colab; otherwise a relative local path is used below)
if ON_COLAB:
    # target folder for the downloaded file
    !mkdir -p data
    # folder with the labeled annotation files in the replication repository
    GITHUB_PATH="https://raw.githubusercontent.com/haukelicht/group_mention_detection/refs/heads/main/replication/data/annotation/labeled/"
    # fetch the labeled UK manifesto sentences to the location `args.data_file` points to on Colab
    !wget -O data/uk-manifestos_all_labeled.jsonl -q $GITHUB_PATH/uk-manifestos_all_labeled.jsonl

### Define the arguments

*Note:* 
I'm using `types.SimpleNamespace` here so that the object behaves similarly to the output of `argparse.ArgumentParser().parse_args()`.
This way, it's very easy to convert this notebook to an executable python script.

In [None]:
from pathlib import Path
from types import SimpleNamespace

# all settings of the notebook, collected in a single namespace
args = SimpleNamespace(
    # labeled data: downloaded copy on Colab, relative path into the replication folder otherwise
    data_file=(
        Path('data/uk-manifestos_all_labeled.jsonl')
        if ON_COLAB
        else Path('../replication/data/annotation/labeled/uk-manifestos_all_labeled.jsonl')
    ),
    dev_size=0.1,
    test_size=0.2,

    # model name in huggingface model hub
    model_name="answerdotai/ModernBERT-base",  # "answerdotai/ModernBERT-large"

    # path where to save temporary files and the final model (if desired)
    output_path=Path('results/finetuning'),

    ## hyperparameters
    epochs=10,
    learning_rate=4e-5,
    train_batch_size=16,
    gradient_accumulation_steps=2,
    eval_batch_size=32,
    weight_decay=0.3,

    ## early stopping
    early_stopping=True,
    # metric used for early stopping and to select the best model among saved checkpoints after stopping
    metric='seqeval-social group_f1',
    early_stopping_patience=3,
    early_stopping_threshold=0.03,

    # for reproducibility
    seed=42,

    save_finetuned_model=False,
)

### Load required libraries

In [None]:
import shutil
import numpy as np
import pandas as pd
import json

from datasets import Dataset, DatasetDict

import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
    set_seed,
)
set_seed(args.seed)
# uncomment the next two lines if you want to suppress the logging output
# from transformers.utils import logging
# logging.set_verbosity_error()

from soft_seqeval.metrics import compute_sequence_metrics

### Define custom helper functions

In [None]:
import json
from typing import Dict, Any, List, Union
def read_jsonl(path: Union[Path, str], replace_newlines: bool = False) -> List[Dict[str, Any]]:
    """
    Read a JSON Lines file from `path` into a list of parsed objects.

    Only plain-text files are handled; compressed inputs (.zip, .gz) are NOT
    decompressed by this function.

    Args:
        path: location of the JSON Lines file (one JSON document per line).
        replace_newlines: if True, every escaped newline sequence (a backslash
            followed by ``n``) in the raw line is replaced by a single space
            before parsing, so decoded strings contain no line breaks.

    Returns:
        One parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
    """
    records = []
    # read as UTF-8 explicitly instead of relying on the platform's default encoding
    with open(path, encoding="utf-8") as infile:
        for line in infile:
            # lines keep their trailing "\n", so a plain truthiness test would
            # never skip blank lines; strip before testing
            if not line.strip():
                continue
            if replace_newlines:
                line = line.replace("\\n", " ")
            records.append(json.loads(line))
    return records

In [None]:
# see also: https://github.com/haukelicht/group_mention_detection/blob/main/replication/code/utils/classification.py
def tokenize_and_align_sequence_labels(examples, tokenizer, **kwargs) -> Dict:
    """Tokenize pre-split words and align word-level labels with the resulting tokens.

    `examples['tokens']` is passed to `tokenizer` with `is_split_into_words=True`
    (plus any `kwargs`). For every sequence, the first token of each word receives
    the word's label from `examples['labels']`; all other tokens (tokens without a
    word ID and the remaining tokens of a word) receive -100.

    Returns the tokenizer output with the aligned labels stored under 'labels'.
    """
    # source: simplified from https://github.com/huggingface/transformers/blob/730a440734e1fb47c903c17e3231dac18e3e5fd6/examples/pytorch/token-classification/run_ner.py#L442
    encoded = tokenizer(examples['tokens'], is_split_into_words=True, **kwargs)

    ignore_id = -100
    aligned = []
    for row, word_labels in enumerate(examples['labels']):
        # word index of every token in this sequence (None where a token has no word)
        word_ids = list(encoded.word_ids(batch_index=row))
        # the same sequence shifted right by one, so each token can see its predecessor's word
        preceding = [None] + word_ids[:-1]
        aligned.append([
            ignore_id if (cur is None or cur == prev) else word_labels[cur]
            for prev, cur in zip(preceding, word_ids)
        ])

    encoded['labels'] = aligned
    return encoded

In [None]:
from datasets import Dataset
# see also: https://github.com/haukelicht/group_mention_detection/blob/main/replication/code/utils/classification.py
def create_token_classification_dataset(
    data: List[Dict], 
    tokens_field: str='tokens',
    labels_field: Union[None, str]='labels'
):
    """Build a `Dataset` that holds only a 'tokens' column and, optionally, a 'labels' column.

    Args:
        data: list of record dictionaries.
        tokens_field: name of the field in the records that holds the tokens.
        labels_field: name of the field that holds the labels, or None if the
            records are unlabeled (then only 'tokens' is kept).

    Returns:
        The dataset with the fields renamed to 'tokens'/'labels' and all other columns removed.
    """
    ds = Dataset.from_list(data)

    # bring the column names into the canonical form
    renames = [(tokens_field, 'tokens')]
    if labels_field is not None:
        renames.append((labels_field, 'labels'))
    for current, canonical in renames:
        if current != canonical:
            ds = ds.rename_column(current, canonical)

    # drop every column that is not needed
    keep = ['tokens'] if labels_field is None else ['tokens', 'labels']
    surplus = [col for col in ds.column_names if col not in keep]
    return ds.remove_columns(surplus) if surplus else ds

## 2. Load and prepare the data


Our data files at https://github.com/haukelicht/group_mention_detection/blob/main/replication/data/annotation/labeled/ contain the token-level labels aggregated from our two coders' group mention annotations.

Each file is a [JSONlines file](https://jsonlines.org/), that is, a text file with one JSON dictionary per line.

An exemplary line looks like this:

```json
{
 "id": "829ac29cd9304a66265e3ea830a505e3",
 "text": "Seit 150 Jahren machen wir Politik für eine bessere Gesellschaft .",
 "tokens": [
  "Seit",
  "150",
  "Jahren",
  "machen",
  "wir",
  "Politik",
  "für",
  "eine",
  "bessere",
  "Gesellschaft",
  "."],
 "annotations": {"emarie": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
 "labels": {"BSCModel": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]},
 "metadata": {"sentence_id": "41320.000.2013.1.1-2520-1",
  "split_": "smarie",
  "job": "group-mentions-annotation-de-manifestos-round-01"}
}
```

Note that each dictionary records 

- the text in pre-tokenized format,
- annotations and labels at the _token_ level (a dictionary of lists, one per annotator),

In particular, 

- Annotations are in the field `"annotations"` and map annotator IDs to their token-level annotations.
- **Labels** are in the field `"labels"`, and `"BSCModel"` records the Bayesian Sequence Combination (BSC) model-based aggregate labels.


In [None]:
def parse_record(d):
    """Reduce a raw record to its ID, its tokens, and its 'BSCModel' labels."""
    record = {key: d[key] for key in ('id', 'tokens')}
    record['labels'] = d['labels']['BSCModel']
    return record

# read the raw records from the JSON Lines file
data = read_jsonl(args.data_file)

# keep only the ID, the tokens, and the aggregated ('BSCModel') labels of each record
data = [parse_record(d) for d in data]

In [None]:
# print one example record token by token, each with its label ID
doc = data[5]
for t, l in zip(doc['tokens'], doc['labels']):
    print(f"{t} ==> {l}")

In [None]:
# show available label IDs (distinct values across all records)
{label_id for record in data for label_id in record['labels']}
# NOTE: implies 5 types (0 reserved for outside token and otherwise 2 IDs (B and I) per group type)

The records in `"labels"` are numeric label IDs that map onto our **group types**:

- social group
- political group
- political institution
- organization, public institution, or collective actor
- implicit social group reference

Here is how to convert them to text labels:

In [None]:
# get list of entity types
types = [
  "social group",
  "political group",
  "political institution",
  "organization, public institution, or collective actor",
  "implicit social group reference",
]
# convert to IOB2 scheme
scheme = ['O'] + ['I-'+t for t in types] + ['B-'+t for t in types]
# map label type indicators to label IDs
label2id = {l: i for i, l in enumerate(scheme)}
# and vice versa
id2label = {i: l for i, l in enumerate(scheme)}
NUM_LABELS = len(label2id)

label2id
# NOTE: the span-level annotations will be converted to token-level annotations using the IOB2 scheme.
#       This means that 
#        - a word that is not part of any entity will be labeled as "O",
#        - a word at the beginning of a span will be labeled as "B-<entity_type>", and 
#        - a word inside a span will be labeled as "I-<entity_type>"

### Split the data

In [None]:
from typing import Optional, Union
from sklearn.model_selection import train_test_split
def split_data(
        data: List[Dict],
        test_size: Union[None, float, int]=0.2,
        dev_size: Union[None, float, int]=0.2,
        stratify_by: Optional[Union[str, List[str]]]=None,
        seed: int=42,
        return_dict: bool=False
    ):
    """Split a corpus into training, development, and test sets.

    The test set is split off first; the development set is then split off
    from the remaining documents. `dev_size` is therefore applied to the
    documents left after removing the test set, not to the full corpus.

    Args:

    data: List[Dict]
        The corpus to split. Must be a list of dictionaries.
    dev_size: float or int
        Size of the development set, passed as `test_size` to the second
        `train_test_split` call (a float is a proportion of the documents
        remaining after the test split).
    test_size: float or int
        Size of the test set, passed as `test_size` to the first
        `train_test_split` call (a float is a proportion of the corpus).
    stratify_by: str or list of str, optional
        Metadata field(s) to use for stratified splitting. If a single field is 
        provided, the data will be stratified by the values in that field in the metadata. 
        If multiple columns are provided, the data will be stratified by 
        the unique combinations of values of these fields in the metadata.
    seed: int
        Random seed for reproducibility (used for both splits).
    return_dict: bool
        Whether to return the splits as a dictionary.

    Returns:

    A tuple `(train, dev, test)` of lists of documents, or a dictionary with
    the keys 'train', 'dev', and 'test' if `return_dict` is True.
    """
    if stratify_by:
        assert all('metadata' in doc for doc in data), "Stratification requires 'metadata' field in each document's dictionary"
        if isinstance(stratify_by, str):
            stratify_by = [stratify_by]
        for field in stratify_by:
            assert all(field in doc['metadata'] for doc in data), f"Field '{field}' not found in 'metadata' of all documents"
        # create a grouping indicator based on the stratification columns
        strata = ['__'.join([str(doc['metadata'][field]) for field in stratify_by]) for doc in data]
    else:
        strata = None

    idxs = list(range(len(data)))
    rest_idxs, test_idxs = train_test_split(idxs, test_size=test_size, random_state=seed, stratify=strata)
    # the second split only sees the remaining documents, so its strata must be
    # those of `rest_idxs` (in the same order); taking them from the test
    # indices would misalign the strata with the documents being split
    rest_strata = [strata[i] for i in rest_idxs] if strata is not None else None
    train_idxs, dev_idxs = train_test_split(rest_idxs, test_size=dev_size, random_state=seed, stratify=rest_strata)

    train, dev, test = [data[i] for i in train_idxs], [data[i] for i in dev_idxs], [data[i] for i in test_idxs]

    if return_dict:
        return {'train': train, 'dev': dev, 'test': test}
    else:
        return train, dev, test

In [None]:
# split the records into train/dev/test; `return_dict=True` yields a dictionary keyed by split name
data_splits = split_data(data, test_size=args.test_size, dev_size=args.dev_size, seed=args.seed, return_dict=True)

In [None]:
# load the (fast) tokenizer that belongs to the model
tokenizer = AutoTokenizer.from_pretrained(
    args.model_name,
    use_fast=True,
    add_prefix_space=True,
)
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)


def _tokenize_batch(batch):
    """Tokenize one batch and set the labels of non-first subword tokens to -100."""
    return tokenize_and_align_sequence_labels(batch, tokenizer=tokenizer)


# one dataset per split, then tokenize and align the labels batch-wise
# (the alignment is necessary because the tokenization may split a word into multiple subwords)
datasets = DatasetDict({
    split_name: create_token_classification_dataset(records)
    for split_name, records in data_splits.items()
})
datasets = datasets.map(_tokenize_batch, batched=True)

In [None]:
datasets.num_rows

In [None]:
# show an example: print each label next to its decoded token, stopping at the first padding token
example = datasets['train'][2]
for t, l in zip(example['input_ids'], example['labels']):
    if t == tokenizer.pad_token_id:
        break
    print(l, '\t', repr(tokenizer.decode(t)))

In [None]:
# NOTE: after tokenization, text tokens are represented by their token IDs,
#        so we can remove the raw tokens from the dataset (no need to load these to the GPU)
datasets = datasets.remove_columns(['tokens']) 

## Prepare fine-tuning

In [None]:
# NOTE: the `model_init` function is used by the Trainer to initialize the model
#   and is called each time before training starts.
#  So we define it here to load the model from the Huggingface model hub
#   and set the number of labels to the number of unique labels in the dataset
#   and the label2id and id2label mappings
def model_init():
    """Return a token classification model loaded from `args.model_name`.

    The model configuration is given the number of labels and the
    label <-> ID mappings defined in this notebook.
    """
    cfg = AutoConfig.from_pretrained(args.model_name)
    # size of the label set and the label vocabulary in both directions
    cfg.num_labels = NUM_LABELS
    cfg.label2id = label2id
    cfg.id2label = id2label
    model = AutoModelForTokenClassification.from_pretrained(
        args.model_name,
        config=cfg,
        device_map='auto',
    )
    return model

In [None]:
def compute_metrics(p):
    """Compute sequence labeling metrics from a `(predictions, labels)` pair.

    The predicted label ID of each token is the argmax over axis 2 of the
    predictions. Predictions and labels are converted to nested lists of ints
    and passed to `compute_sequence_metrics` with flattened output.
    """
    raw_scores, gold = p
    # pick the highest-scoring label ID per token and convert to list of lists of ints
    pred_ids = np.argmax(raw_scores, axis=2).astype(int).tolist()
    gold_ids = gold.astype(int).tolist()
    return compute_sequence_metrics(y_true=gold_ids, y_pred=pred_ids, id2label=id2label, flatten_output=True)
# NOTE: the `compute_metrics` function is used by the Trainer to compute the evaluation metrics 

In [None]:
# NOTE: at the beginning of the script, we have defined args.metric as the metric to be used for early stopping
#       and model selection among saved checkpoints after stopping
#       This metric must be available in the output of our `compute_metrics` function defined above
#       So let's check this

# score a toy sequence against itself just to obtain the names of the available metrics
ex = ['O', 'B-social group', 'I-social group', 'O']
scores = compute_sequence_metrics([ex], [ex], id2label, flatten_output=True)
if args.metric not in scores:
    valid = ', '.join(scores)
    raise ValueError(f"Invalid metric: {args.metric}, valid metrics are: {valid}")

In [None]:
# look at metrics computed by the `compute_sequence_metrics` function provided by my soft-seqeval package
# NOTE: the keys follow the pattern <scheme>-<entity_type>_<metric> (e.g., "seqeval-social group_f1") where 
#       - <scheme> is one of
#           - "seqeval": strict seqeval metric as implemented in https://github.com/chakki-works/seqeval
#           - "softseqeval": soft-seqeval metrics implemented as described here https://github.com/haukelicht/soft-seqeval/blob/main/notebooks/available_metrics.ipynb
#           - "wordlevel": evaluation at word level
#           - "doclevel": evaluation at document (i.e., sentence) level
#       - <entity_type> is the entity type (e.g., "social group") or "macro" or "micro" for macro or micro average over all entity types
#       - <metric> is the metric used for evaluation (f1, precision, or recall)
metrics_overview = pd.DataFrame(list(scores.keys()), columns=['key'])
# first split off the scheme at the hyphen, then split the remainder into entity type and metric at the underscore
metrics_overview[['scheme', 'metric']] = metrics_overview.key.str.split('-', n=2, expand=True)
metrics_overview[['type', 'metric']] = metrics_overview.metric.str.split('_', n=2, expand=True)
metrics_overview

**Note:** Because we are most interested in detecting (exact) occurrences of social group mentions, we will use `seqeval-social group_f1`


### Define the training arguments

In [None]:
# temporary training artifacts are written below the output path
# (both folders are deleted again in the clean-up step at the end of the notebook)
out_dir = args.output_path
checkpoints_dir = out_dir / 'checkpoints'
logs_dir = out_dir / 'logs'

training_args = TrainingArguments(
    
    # hyperparameters
    num_train_epochs=args.epochs,
    learning_rate=args.learning_rate,
    per_device_train_batch_size=args.train_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    per_device_eval_batch_size=args.eval_batch_size,
    weight_decay=args.weight_decay,
    optim='adamw_torch',
    
    # when to evaluate
    eval_strategy='epoch',
    # how to select "best" model
    do_eval=bool('dev' in datasets),
    metric_for_best_model=args.metric,
    load_best_model_at_end=True,
    # when to save (same schedule as the evaluation, as required for `load_best_model_at_end`)
    save_strategy='epoch',
    save_total_limit=2 if 'dev' in datasets else None, # don't save all model checkpoints
    # where to store results
    output_dir=checkpoints_dir,
    overwrite_output_dir=True,
    
    # logging
    logging_dir=logs_dir,
    logging_strategy='epoch',
    report_to='none',
    
    # reproducibility
    seed=args.seed,
    data_seed=args.seed,
    full_determinism=True
)


# build callbacks
callbacks = []
if args.early_stopping:
    # early stopping monitors the evaluation metric, so a dev split must exist
    if 'dev' not in datasets:
        raise ValueError('Early stopping requires a dev data set')
    callbacks.append(EarlyStoppingCallback(early_stopping_patience=args.early_stopping_patience, early_stopping_threshold=args.early_stopping_threshold))


## Fine-tuning

### Create the trainer

In [None]:
# NOTE: the model is passed through `model_init` (not `model`), so the Trainer creates it itself;
#       training uses the 'train' split and evaluation during training uses the 'dev' split
trainer = Trainer(
    model_init=model_init,
    args=training_args,
    train_dataset=datasets['train'],
    eval_dataset=datasets['dev'],
    processing_class=tokenizer,
    data_collator=DataCollatorForTokenClassification(tokenizer),
    compute_metrics=compute_metrics,
    callbacks=callbacks
)

### Fine-tune

In [None]:
# run the fine-tuning and keep the object returned by `train()` for later inspection
print('Training ...')
train_hist = trainer.train()

## Evaluate

In [None]:
# apply the best model loaded after finishing training to the test set
# (`metric_key_prefix='test'` prefixes the keys of the returned metrics with 'test')
print('Evaluating ...')
test_res = trainer.evaluate(datasets['test'], metric_key_prefix='test')

In [None]:
# create a nicer-to-look-at output: one row per (entity type, scheme), one column per metric
# one row per result key, with the key in column 'cat' and the score in column 'value'
out = pd.DataFrame(test_res, index=['value']).T
out = out.reset_index().rename(columns={'index': 'cat'})
# split the keys at underscores; metric keys presumably look like 'test_<scheme>-<type>_<metric>'
out[['set', 'scheme', 'metric', 'misc']] = out.cat.str.split('_', expand=True)
# drop keys with a fourth part and keys with fewer than three parts (i.e., not of the form above)
out = out[out.misc.isnull()]
out = out[out.metric.notnull()]
# separate the evaluation scheme from the entity type
out[['scheme', 'type']] = out.scheme.str.split('-', expand=True)
out = out.drop(columns=['set', 'cat', 'misc'])
out = out[['scheme', 'type', 'metric', 'value']]
out = out.pivot(index=['type', 'scheme', ], columns='metric', values='value')
# show the rows in a fixed order: group types as listed in `types`, each with the four schemes
keys = [
    (typ, scheme)
    for typ in types
    for scheme in ['seqeval', 'softseqeval', 'wordlevel', 'doclevel']
]
out.loc[keys, :]

## Inference

In [None]:
from datasets import Dataset

from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset

from tqdm import tqdm

We'll use the transformer's `pipeline` for inference (i.e., predicting spans in unlabeled data).

Specifically, we use the **NER** (named entity recognition) task and pass the fine-tuned model from the trainer.

In [None]:
extractor = pipeline(task='ner', model=trainer.model, tokenizer=tokenizer, batch_size=32, aggregation_strategy='simple')

In [None]:
# re-read the raw records (they still contain the 'text' field)
docs = read_jsonl(args.data_file)

# take the first 500 sentences just for illustrative purposes
docs = docs[:500]

# keep only text and id fields
docs = [{f: doc[f] for f in ['text', 'id']} for doc in docs]

# for batch inference (see https://huggingface.co/docs/transformers/pipeline_tutorial#batch-inference)
kd = KeyDataset(Dataset.from_list(docs), 'text')

In [None]:
# apply the extractor to the dataset and collect the per-text outputs (with a progress bar)
pred_ents = list(tqdm(extractor(kd), total=len(docs)))

For each text in the list of texts taken from `docs`, we get a list of dictionaries (one per predicted entity); `pred_ents` collects these lists.

Each of these dictionaries has the following fields:

- start: character start index of the entity in the text
- end: character end index of the entity in the text
- score: confidence score of the prediction
- word: the text of the entity
- entity_group: the entity type (e.g., 'social group')


Let's convert these annotations into one `Entities` instance and create a new `LabeledSequence` instance from this information for each text:  

In [None]:
from soft_seqeval.classes import Entity, Entities
from soft_seqeval.classes import LabeledSequence
from copy import deepcopy

def pipeline_output_to_entities(pred) -> Entities:
    """Take output from the NER pipeline and convert to Entities instance

    Args:
        pred: iterable of dictionaries for one text, each with the keys
            'start', 'end', 'word', and 'entity_group'.

    Returns:
        An `Entities` instance with one `Entity(start, end, entity_group)` per
        input dictionary. If 'word' begins (ends) with a space, the start (end)
        offset is moved inward by one character.
    """
    ents = []
    for ent in pred:
        # work on local copies of the offsets so the pipeline output is not modified
        start, end = ent['start'], ent['end']
        word = ent['word']
        # startswith/endswith also work on an empty string, where word[0] / word[-1] would raise
        if word.startswith(' '):
            start += 1
        if word.endswith(' '):
            end -= 1
        ents.append(Entity(start, end, ent['entity_group']))
    return Entities(ents)

# iterate over the documents and predicted annotations to create a list of LabeledSequence instances
preds = [
    LabeledSequence(text=doc['text'], entities=pipeline_output_to_entities(pred), id=doc['id'], lang='english')
    for doc, pred in zip(docs, pred_ents)
]

In [None]:
# look at first 10 examples
preds[:10]


## Finally

#### Delete intermediate checkpoints and log files

In [None]:
# finally: clean up the checkpoint and log folders (if they were created)
for tmp_dir in (checkpoints_dir, logs_dir):
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir)

#### Save the best model (if desired)

In [None]:
if args.save_finetuned_model:
    trainer.save_model(out_dir)
    tokenizer.save_pretrained(out_dir)

    # # save results
    # import json
    # fp = os.path.join(out_dir, 'test_results.json')
    # with open(fp, 'w') as f:
    #     json.dump(test_res, f)

### Free the GPU and remove large objects

In [None]:
import gc
import torch

# move the model weights to the CPU; the module is moved in place, so the result
# does not need to be assigned (assigning it to `trainer` would only replace the
# Trainer object by the model)
trainer.model.to('cpu')
# NOTE: other objects that reference the model (e.g., the `extractor` pipeline
#       created above from `trainer.model`) keep it alive until they are deleted too
del trainer, tokenizer, datasets
gc.collect()
# release the memory cached by PyTorch's CUDA allocator so the GPU is actually freed
if torch.cuda.is_available():
    torch.cuda.empty_cache()