In [1]:
# imports
import tqdm as notebook_tqdm
import os
import re
import logging
import pandas as pd
from pprint import pprint
import numpy as np
from seqeval.metrics import (precision_score,
                             recall_score,
                             f1_score,
                             classification_report)

from datasets import (load_dataset,
                      DatasetDict, 
                      Features, 
                      Sequence, 
                      ClassLabel, 
                      Value, 
                      interleave_datasets, 
                      get_dataset_config_names, 
                      load_dataset, 
                      load_from_disk
)

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    DataCollatorForTokenClassification,
    Trainer,
    TrainingArguments,
    set_seed
)

from utility_functions import (split_sources,
                               whitespace_tokens_with_spans,
                               spans_to_bio_labels,
                               build_bio_label_list_from_sources, 
                               compute_metrics_factory, 
                               make_to_features, 
                               process_all,
                               normalize_types
)


  from .autonotebook import tqdm as notebook_tqdm


## 1. Data Loading and Exploration

In [3]:
# List every configuration published for the dataset on the Hugging Face Hub.
# (Earlier versions of this notebook read from the "bigbio/swedish_medical_ner" mirror.)
configs = get_dataset_config_names("community-datasets/swedish_medical_ner")
print("Available configurations:", *(f"- {name}" for name in configs), sep="\n")

Available configurations:
- 1177
- lt
- wiki


In [4]:
# Load every configuration listed above. Each DatasetDict is cached under data/
# so that later runs read it from disk instead of downloading it again.
kb_datasets = []
for config in configs:
    cache_dir = f"data/swedish_medical_ner_{config}"
    # save_to_disk() writes a *directory*, so the cache check must use isdir;
    # os.path.isfile() is always False here and the cache was never hit.
    if os.path.isdir(cache_dir):
        print(f"Loading configuration: {config} from disk")
        ds = load_from_disk(cache_dir)
    else:
        print(f"Loading configuration: {config} from the huggingface hub")
        ds = load_dataset("community-datasets/swedish_medical_ner", config)
        ds.save_to_disk(cache_dir)
    # A DatasetDict does not carry its config name (load_from_disk certainly does
    # not restore it), so attach it in both branches: later cells (e.g. the
    # exploration cell) read `ds.config_name`.
    ds.config_name = config
    print(f"- {ds.config_name}")
    pprint(ds["train"].features)  # display schema
    kb_datasets.append(ds)


Loading configuration: 1177 from the huggingface hub
- 1177
{'entities': Sequence(feature={'end': Value(dtype='int32', id=None),
                               'start': Value(dtype='int32', id=None),
                               'text': Value(dtype='string', id=None),
                               'type': ClassLabel(names=['Disorder and Finding',
                                                         'Pharmaceutical Drug',
                                                         'Body Structure'],
                                                  id=None)},
                      length=-1,
                      id=None),
 'sentence': Value(dtype='string', id=None),
 'sid': Value(dtype='string', id=None)}


Saving the dataset (1/1 shards): 100%|██████████| 927/927 [00:00<00:00, 120993.30 examples/s]

Loading configuration: lt from the huggingface hub





- lt
{'entities': Sequence(feature={'end': Value(dtype='int32', id=None),
                               'start': Value(dtype='int32', id=None),
                               'text': Value(dtype='string', id=None),
                               'type': ClassLabel(names=['Disorder and Finding',
                                                         'Pharmaceutical Drug',
                                                         'Body Structure'],
                                                  id=None)},
                      length=-1,
                      id=None),
 'sentence': Value(dtype='string', id=None),
 'sid': Value(dtype='string', id=None)}


Saving the dataset (1/1 shards): 100%|██████████| 745753/745753 [00:00<00:00, 1103035.72 examples/s]


Loading configuration: wiki from the huggingface hub
- wiki
{'entities': Sequence(feature={'end': Value(dtype='int32', id=None),
                               'start': Value(dtype='int32', id=None),
                               'text': Value(dtype='string', id=None),
                               'type': ClassLabel(names=['Disorder and Finding',
                                                         'Pharmaceutical Drug',
                                                         'Body Structure'],
                                                  id=None)},
                      length=-1,
                      id=None),
 'sentence': Value(dtype='string', id=None),
 'sid': Value(dtype='string', id=None)}


Saving the dataset (1/1 shards): 100%|██████████| 48720/48720 [00:00<00:00, 790270.21 examples/s]


In [5]:
# Sanity check: exactly one DatasetDict was loaded per configuration listed above.
assert len(kb_datasets) == len(configs), (
    f"expected {len(configs)} datasets, got {len(kb_datasets)}"
)
print(len(kb_datasets), "datasets loaded.")

3 datasets loaded.


In [6]:
# Explore every loaded configuration. For each one, print the train split's
# row and column counts, then show a short preview as a pandas DataFrame.
N_PREVIEW_ROWS = 10

for i, dataset in enumerate(kb_datasets):
    train_split = dataset["train"]
    print(f"### Configuration {i + 1}: {dataset.config_name}")
    print(f"Rows: {train_split.num_rows}")  # number of examples in the train split
    print(f"Columns: {len(train_split.features)}")  # number of columns/features

    # Clamp to the split size: select() fails on indices past the end.
    n_rows = min(N_PREVIEW_ROWS, train_split.num_rows)
    preview = train_split.select(range(n_rows)).to_pandas()
    # Show full cell contents (long sentences / entity dicts) for this preview only,
    # without changing pandas' global display options for the rest of the notebook.
    with pd.option_context("display.max_colwidth", None):
        display(preview)

### Configuration 1: 1177
Rows: 927
Columns: 3


Unnamed: 0,sid,sentence,entities
0,1177_0,Memantin ( Ebixa ) ger sällan några biverkningar.,"{'start': [9], 'end': [18], 'text': ['Ebixa'], 'type': [0]}"
1,1177_1,Det är också lättare att dosera [ flytande medicin ] än att dela på tabletter.,"{'start': [32], 'end': [52], 'text': ['flytande medicin'], 'type': [1]}"
2,1177_2,( Förstoppning ) är ett vanligt problem hos äldre.,"{'start': [0], 'end': [16], 'text': ['Förstoppning'], 'type': [0]}"
3,1177_3,[ Medicinen ] kan också göra att man blöder lättare eftersom den påverkar { blodets } förmåga att levra sig.,"{'start': [0, 74], 'end': [13, 85], 'text': ['Medicinen', 'blodets'], 'type': [1, 2]}"
4,1177_4,Barn har större möjligheter att samarbeta om de i förväg får veta vad som ska hända.,"{'start': [], 'end': [], 'text': [], 'type': []}"
5,1177_5,Eftersom de påverkar hela kroppen mer än övriga mediciner bör man bara ta dem när olika kombinationer av receptfria mediciner inte hjälper.,"{'start': [], 'end': [], 'text': [], 'type': []}"
6,1177_6,För att få ett skydd mot ( hepatit B ) behövs tre doser vaccin.,"{'start': [25], 'end': [38], 'text': ['hepatit B'], 'type': [0]}"
7,1177_7,Effekten av naproxen sitter i längre och varar cirka 12 timmar jämfört med cirka 6 timmar för ibuprofen.,"{'start': [], 'end': [], 'text': [], 'type': []}"
8,1177_8,[ Cox-hämmare ] finns även som gel och sprej.,"{'start': [0], 'end': [15], 'text': ['Cox-hämmare'], 'type': [1]}"
9,1177_9,"Det är bra om ett litet barn är mätt och utsövt, eftersom de flesta påfrestningar då känns mindre.","{'start': [], 'end': [], 'text': [], 'type': []}"


### Configuration 2: lt
Rows: 745753
Columns: 3


Unnamed: 0,sid,sentence,entities
0,lt_0,", (hjärtinfarkt) och (syndrom) som vi nu år 1999 inte ens vet na","{'start': [2, 21], 'end': [16, 30], 'text': ['hjärtinfarkt', 'syndrom'], 'type': [0, 0]}"
1,lt_1,"tinernas goda effekt på morbiditeten är välkänd, och data hi","{'start': [], 'end': [], 'text': [], 'type': []}"
2,lt_2,"[sukralfat], [lakrits] och vismut) som kunde utgöra ett skydd öv","{'start': [0, 13], 'end': [11, 22], 'text': ['sukralfat', 'lakrits'], 'type': [1, 1]}"
3,lt_3,och tveksamhet {vad} gäller operationsindikationen kan man ha,"{'start': [16], 'end': [21], 'text': ['vad'], 'type': [2]}"
4,lt_4,1989 blev en anmälningspliktig (sjukdom) enligt Smittskyddsla,"{'start': [32], 'end': [41], 'text': ['sjukdom'], 'type': [0]}"
5,lt_5,kombinerat med remodellering av (hjärtat). Detta säkras genom,"{'start': [32], 'end': [41], 'text': ['hjärtat'], 'type': [0]}"
6,lt_6,olyckshändelse radikalt förändrat deras liv. {Sigmoideum} är,"{'start': [46], 'end': [58], 'text': ['Sigmoideum'], 'type': [2]}"
7,lt_7,ra att hon samtidigt ordinerade [Cyklokapron] i en mängd av 5,"{'start': [32], 'end': [45], 'text': ['Cyklokapron'], 'type': [1]}"
8,lt_8,till vara erfarenheterna och föra ut kunskapen till sjukvård,"{'start': [], 'end': [], 'text': [], 'type': []}"
9,lt_9,es kring behandling med betablockad vid (kronisk hjärtsvikt).,"{'start': [40], 'end': [60], 'text': ['kronisk hjärtsvikt'], 'type': [0]}"


### Configuration 3: wiki
Rows: 48720
Columns: 3


Unnamed: 0,sid,sentence,entities
0,wiki_0,"{kropp} beskrivs i till exempel människokroppen, anatomi och f","{'start': [0], 'end': [7], 'text': ['kropp'], 'type': [2]}"
1,wiki_1,"sju miljoner år gammalt hominint {kranium}, klassificerad som","{'start': [33], 'end': [42], 'text': ['kranium'], 'type': [2]}"
2,wiki_2,autosomer och ett par könskromosomer. Varje {kromosom} består,"{'start': [45], 'end': [55], 'text': ['kromosom'], 'type': [2]}"
3,wiki_3,{kromosom} består av en DNA-molekyl och {protein}. En DNA-molek,"{'start': [1], 'end': [50], 'text': ['kromosom} består av en DNA-molekyl och {protein'], 'type': [2]}"
4,wiki_4,tikel:Människans {skelett} Människans skelett är det skelett s,"{'start': [17], 'end': [26], 'text': ['skelett'], 'type': [2]}"
5,wiki_5,os människor. En vuxen människas {skelett} består av 206 till,"{'start': [33], 'end': [42], 'text': ['skelett'], 'type': [2]}"
6,wiki_6,"{lett} består av 206 till 220 {ben}, beroende på hur man räknar.","{'start': [0], 'end': [35], 'text': ['lett} består av 206 till 220 {ben'], 'type': [2]}"
7,wiki_7,v kroppsvikten.Ett nyfött barn har ca 300 {ben} i kroppen vilk,"{'start': [42], 'end': [47], 'text': ['ben'], 'type': [2]}"
8,wiki_8,kollektivet i mindre bitar såsom länder > städer > orter {Hud},"{'start': [57], 'end': [62], 'text': ['Hud'], 'type': [2]}"
9,wiki_9,sdjur. {Huden} utgör ett mekaniskt skydd mot omvärlden och bid,"{'start': [7], 'end': [14], 'text': ['Huden'], 'type': [2]}"


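The data of interest are the sentence text (`sentence`) and the named entities with their types (`entities`).
Each named entity is a span: a contiguous sequence of tokens (words, subwords, or characters) in the text that together represent a single entity. Here spans are given as character offsets (`start`, `end`) into the sentence, and in the previews above these offsets include the surrounding bracket markers. Note that some `wiki` spans are malformed and run across two bracketed entities (rows 3 and 6 of the `wiki` preview).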

In [7]:
# 1) Hold out 5% of every configuration as a validation split.
#    The seed is fixed so that the split is reproducible across runs.
per_source_raw = split_sources(kb_datasets, seed=42, val_fraction=0.05)
print(per_source_raw)

{'1177': DatasetDict({
    train: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 880
    })
    validation: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 47
    })
}), 'lt': DatasetDict({
    train: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 708465
    })
    validation: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 37288
    })
}), 'wiki': DatasetDict({
    train: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 46284
    })
    validation: Dataset({
        features: ['sid', 'sentence', 'entities', 'source'],
        num_rows: 2436
    })
})}


In [8]:
# 2) Build the global BIO label set (union of entity types over all configs)
#    and the id <-> label mappings that the model config needs.
label_list = build_bio_label_list_from_sources(per_source_raw)
label2id = {label: idx for idx, label in enumerate(label_list)}
id2label = {idx: label for label, idx in label2id.items()}
# Duplicate names in label_list would silently collapse in label2id and
# desynchronize num_labels from the mappings.
assert len(label2id) == len(label_list), "label list contains duplicate labels"
print(label_list)
print(label2id)
print(id2label)


['O', 'B-body_structure', 'I-body_structure', 'B-disorder_finding', 'I-disorder_finding', 'B-pharmaceutical_drug', 'I-pharmaceutical_drug']
{'O': 0, 'B-body_structure': 1, 'I-body_structure': 2, 'B-disorder_finding': 3, 'I-disorder_finding': 4, 'B-pharmaceutical_drug': 5, 'I-pharmaceutical_drug': 6}
{0: 'O', 1: 'B-body_structure', 2: 'I-body_structure', 3: 'B-disorder_finding', 4: 'I-disorder_finding', 5: 'B-pharmaceutical_drug', 6: 'I-pharmaceutical_drug'}


In [9]:
# 3) Model & tokenizer
MODEL_NAME = "KB/bert-base-swedish-cased"
SEED = 42  # keep in sync with base_args["seed"] in step 5

# Seed *before* building the model. The token-classification head is newly
# (randomly) initialized by from_pretrained (see the warning below), and Trainer
# only re-seeds when it is constructed, which happens later. Without this call,
# the head's initial weights differ on every run.
set_seed(SEED)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label_list),
    id2label=id2label,
    label2id=label2id,
)
# Dynamically pads input ids and label sequences within each batch.
collator = DataCollatorForTokenClassification(tokenizer=tokenizer)
compute_metrics = compute_metrics_factory(id2label=id2label)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at KB/bert-base-swedish-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Apply normalize_types to every split of every config. The result is built as
# fresh DatasetDicts, which leaves per_source_raw untouched. The original code
# instead made a full identity-map copy of every split first (~700k rows for lt).
# NOTE(review): this cell currently fails with
#   ValueError: Invalid string class label pharmaceutical_drug
# Presumably normalize_types writes normalized *string* names (the ones used in
# label_list) into `entities.type`, which is still a ClassLabel over the original
# names ('Disorder and Finding', ...). The fix likely belongs in normalize_types,
# e.g. changing that feature's type before rewriting it. TODO confirm in
# utility_functions.
# NOTE(review): per_source_norm is not consumed downstream; step 4 reads
# per_source_raw.
per_source_norm = {
    cfg: DatasetDict({split: normalize_types(split_ds) for split, split_ds in ds.items()})
    for cfg, ds in per_source_raw.items()
}


Map: 100%|██████████| 880/880 [00:00<00:00, 13352.73 examples/s]
Map: 100%|██████████| 47/47 [00:00<00:00, 5743.11 examples/s]
Map: 100%|██████████| 880/880 [00:00<00:00, 10560.58 examples/s]


ValueError: Invalid string class label pharmaceutical_drug

In [None]:
# 4) Convert span schema → token-classification features
# Build a converter from the tokenizer and the global label2id map.
# max_length=256 presumably caps the tokenized sequence length; TODO confirm in
# utility_functions.make_to_features.
to_features = make_to_features(tokenizer, label2id, max_length=256)
# Apply the converter to every config/split. The training cell indexes the result
# as per_source[cfg]["train"] / per_source[cfg]["validation"].
# NOTE(review): this consumes per_source_raw, not the per_source_norm from the
# normalization cell; confirm which one is intended.
per_source = process_all(per_source_raw, to_features)

In [None]:
# 5) Prepare TrainingArguments base
# Shared TrainingArguments kwargs for every curriculum stage. The staged loop in
# step 6 overrides num_train_epochs, learning_rate and output_dir per stage, so
# the values given here for those three keys are only defaults.
base_args = dict(
    output_dir="outputs/ner_kbbert_multi",  # each stage writes to <output_dir>/stage_<name>
    learning_rate=2e-5,    # default only; overridden per stage
    num_train_epochs=2.0,  # default only; overridden per stage
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    weight_decay=0.01,
    warmup_ratio=0.1,
    gradient_accumulation_steps=1,
    logging_steps=50,
    save_total_limit=2,
    load_best_model_at_end=False,
    seed=42,
    report_to=[],  # disable HF trackers by default
    fp16=False,  # set True on GPUs with fp16 support
    bf16=False,  # set True on hardware with bf16 support
    gradient_checkpointing=True,  # boolean flag: trade recomputation for activation memory
)

In [None]:
# Staged-curriculum hyper-parameters, given as (epochs, learning rate) per stage.
# The learning rate shrinks from stage to stage (lt -> wiki -> 1177).
stage_lt_epochs, stage_lt_lr = 1.0, 2e-5
stage_wiki_epochs, stage_wiki_lr = 1.0, 1e-5
stage_1177_epochs, stage_1177_lr = 2.0, 5e-6

In [None]:
# 6) Train
# Staged curriculum: lt -> wiki -> 1177, in the order listed below. Each stage
# gets its own Trainer, and therefore a fresh optimizer and LR schedule. It
# continues from the weights left by the previous stage, because `model` is
# reassigned to trainer.model at the end of each stage.
stage_specs = [
    ("lt", stage_lt_epochs, stage_lt_lr),
    ("wiki", stage_wiki_epochs, stage_wiki_lr),
    ("1177", stage_1177_epochs, stage_1177_lr),
]

# Map each short stage name to the per_source config id that contains it.
# Substring matching tolerates longer config ids (e.g. "..._lt_source").
# The previous version always stored an unset `cfg` (None) here, so no stage
# was ever found and the cell always raised.
short2cfg = {short: None for short, _, _ in stage_specs}
for config_id in per_source:
    for short in short2cfg:
        if short2cfg[short] is None and short in config_id:
            short2cfg[short] = config_id
            break

order = [
    (short, short2cfg[short], epochs, lr)
    for short, epochs, lr in stage_specs
    if short2cfg[short] is not None
]
if not order:
    raise ValueError(
        f"Could not infer stages from per_source configs {list(per_source)}. "
        "Include lt/wiki/1177 source configs."
    )

stage_metrics = {}  # stage name -> eval metrics on that stage's validation split
for stage_name, cfg, epochs, lr in order:
    print(f"\n=== Stage: {stage_name} on {cfg} | epochs={epochs} lr={lr} ===")
    stage_kwargs = {
        **base_args,
        "num_train_epochs": epochs,
        "learning_rate": lr,
        "output_dir": os.path.join(base_args["output_dir"], f"stage_{stage_name}"),
    }
    stage_args = TrainingArguments(**stage_kwargs)
    trainer = Trainer(
        model=model,
        args=stage_args,
        train_dataset=per_source[cfg]["train"],
        eval_dataset=per_source[cfg]["validation"],
        tokenizer=tokenizer,  # newer transformers prefer `processing_class`; still accepted on 4.x
        data_collator=collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    metrics = trainer.evaluate()
    stage_metrics[stage_name] = metrics
    print(metrics)
    # Carry the trained weights into the next stage.
    model = trainer.model