In [None]:
import sys
# Put the project's helper directories at the front of the module search path,
# '../util' first and '../experiments' second.
sys.path[0:0] = ['../util', '../experiments']

import os
# Disable weights and biases (if installed) and switch off tokenizer parallelism.
os.environ.update({
    "WANDB_DISABLED": "true",
    "TOKENIZERS_PARALLELISM": "false",
})

In [None]:
from pathlib import Path
import transformers
import datasets
from transformers import AutoModelForTokenClassification, AutoTokenizer, Trainer, TrainingArguments, pipeline, DataCollatorForTokenClassification, EarlyStoppingCallback, trainer_utils
from huggingface_utils import load_custom_dataset, LabelAligner, compute_metrics, eval_on_test_set
from run_experiment import get_train_args
from convert_annotations import entity_values

In [None]:
# Keep the notebook output readable: limit `datasets` logging to errors and
# disable the default log handler of `transformers`.
datasets.logging.set_verbosity_error()
transformers.logging.disable_default_handler()

# Parameters

In [None]:
level = 'fine' # Change to 'coarse' to look at high-level entity classes only
spans = 'long' # Change to 'short' to consider short spans ignoring specifications

# Fail fast on a misspelled value instead of hitting an opaque KeyError when the
# (level, spans) pair is later used to look up the experiment config file.
assert level in ('coarse', 'fine'), f"level must be 'coarse' or 'fine', got {level!r}"
assert spans in ('short', 'long'), f"spans must be 'short' or 'long', got {spans!r}"

In [None]:
# Maps each (level, spans) combination to the name of its experiment config file.
# The file names follow the pattern NN_ggponc_<level>_<spans>.yaml.
config_files = {
    ('coarse', 'short'): '01_ggponc_coarse_short.yaml',
    ('fine', 'short'): '02_ggponc_fine_short.yaml',
    ('coarse', 'long'): '03_ggponc_coarse_long.yaml',
    ('fine', 'long'): '04_ggponc_fine_long.yaml',
}

In [None]:
import hydra
from hydra import compose, initialize

# Clear any global Hydra state left over from an earlier run of this cell, so that
# `initialize` can be called again without restarting the kernel.
hydra.core.global_hydra.GlobalHydra.instance().clear()
# Config files are looked up in ../experiments.
initialize(config_path=Path('..') / 'experiments', job_name='foo')
# Compose the experiment config selected by (level, spans), with two overrides applied.
# NOTE(review): 'cuda=0' presumably selects GPU 0 and 'link=false' presumably switches off
# a linking step — confirm against the experiment configs.
config = compose(config_name=config_files[(level, spans)], overrides=['cuda=0', 'link=false'])

In [None]:
# Pull the locations of the three data splits out of the composed config.
train_file, dev_file, test_file = (
    config[split_key]
    for split_key in ('train_dataset', 'dev_dataset', 'test_dataset')
)

# Setup IOB-encoded dataset with train / dev / test splits

In [None]:
# Load the train / dev / test splits together with the tag set via the project helper.
# NOTE(review): `tags` is enumerated further down to build the id-to-label mapping, so it is
# presumably an ordered collection of IOB tag strings — confirm in huggingface_utils.
dataset, tags = load_custom_dataset(train=train_file, dev=dev_file, test=test_file, tag_strings=config['task'])

In [None]:
# Tokenizer belonging to the base checkpoint named in the config.
tokenizer = AutoTokenizer.from_pretrained(config['base_model_checkpoint'])
# Stop early unless a fast tokenizer was loaded.
# NOTE(review): presumably LabelAligner relies on fast-tokenizer features — confirm.
assert isinstance(tokenizer, transformers.PreTrainedTokenizerFast)

In [None]:
# Project helper built around the tokenizer; its tokenize_and_align_labels method is
# applied to the dataset in the next step.
label_aligner = LabelAligner(tokenizer)

In [None]:
# Tokenize the dataset in batches and align the labels with the resulting tokens;
# the `label_all_tokens` setting from the config is passed through to the aligner.
# Note: this rebinds `dataset`, so re-running the cell maps the already-processed data again.
dataset = dataset.map(lambda e: label_aligner.tokenize_and_align_labels(e, config['label_all_tokens']), batched=True)

In [None]:
# Integer class id -> tag string, numbered in the order the tags are listed.
id2label = {tag_id: tag for tag_id, tag in enumerate(tags)}
id2label

In [None]:
# Display the processed dataset for a quick visual check.
dataset

# Configure and train 🤗 token classification model

In [None]:
from run_experiment import get_train_args

In [None]:
# Shortened training run for this notebook. To train with the configured default instead
# (100 epochs according to the original note), skip this cell AND the following one:
# the following cell reads `num_train_epochs`, so deleting only this line raises a NameError there.
num_train_epochs = 10

In [None]:
# Override the number of training epochs in the loaded config (modifies `config` in place).
config['num_train_epochs'] = num_train_epochs

In [None]:
# Build the training arguments from the experiment config via the project helper.
# NOTE(review): cp_path is presumably the checkpoint output directory — confirm in run_experiment.
# Caution: `**config` is expanded next to explicit keywords; if the config ever defines
# cp_path, run_name, report_to or resume_from_checkpoint, this call fails with a
# duplicate-keyword TypeError.
training_args = get_train_args(cp_path='../ner_results', run_name='ner_baseline', report_to=[], **config, resume_from_checkpoint=None)

In [None]:
def model_init():
    """Create a fresh token classification model from the configured base checkpoint.

    Handed to the Trainer as `model_init`, so the Trainer builds the model itself
    instead of receiving a ready-made instance.

    Returns:
        The model produced by AutoModelForTokenClassification.from_pretrained for
        config['base_model_checkpoint'], with one output class per entry in `tags`.
    """
    return AutoModelForTokenClassification.from_pretrained(
        config['base_model_checkpoint'],
        num_labels=len(tags),
        id2label=id2label,
        # Pass the reverse mapping explicitly so the model config (and any checkpoint
        # saved from it) maps the real tag names back to ids, consistent with id2label.
        label2id={label: label_id for label_id, label in id2label.items()},
    )

# Batch collator for token classification, built from the tokenizer.
data_collator = DataCollatorForTokenClassification(tokenizer)

# The Trainer trains on the "train" split and evaluates on the "dev" split.
# NOTE(review): the meaning of the positional `True` given to compute_metrics is defined
# in huggingface_utils — confirm there.
tr = Trainer(
    model_init=model_init,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["dev"],
    compute_metrics=compute_metrics(tags, True),
)

### Train the model

In [None]:
# Run the training loop; the object returned by the Trainer is kept in `train_result`.
train_result = tr.train()

# Evaluate Model

In [None]:
# The model currently held by the Trainer after training.
# NOTE(review): whether this is the best or the last checkpoint depends on the training
# arguments — confirm in run_experiment.get_train_args.
model = tr.model

In [None]:
from transformers.pipelines.token_classification import AggregationStrategy

In [None]:
# Wrap the trained model in a "ner" pipeline that groups tokens into entity spans
# using the FIRST aggregation strategy.
# Note: device=0 targets the first CUDA GPU; on a machine without a GPU this value has to
# be changed (-1 selects the CPU).
pipe = pipeline("ner", model, tokenizer=tokenizer, device=0, aggregation_strategy=AggregationStrategy.FIRST)

In [None]:
# Application to a guideline sentence (German clinical guideline text)
pipe("""Als Alternative empfiehlt die ASCCP bei zytologischem Verdacht auf CIN 1/2 die sofortige Kolposkopie.""")

In [None]:
# Application to clinical text
# The umlaut in "Mobilität" must be a real "ä"; a mis-decoded byte sequence would hand the
# tokenizer a garbled word.
pipe("""Antibiose fortsetzen (s. o.), Abstrich erfragen, ggf. Umstellung der Antibiose. Thromboseprophylaxe bis zur sicheren Mobilität.""")

In [None]:
# Evaluate on the held-out test split via the project helper; the result is read below
# under "test/..." keys.
# NOTE(review): the final "test" argument is presumably the metric key prefix — confirm in huggingface_utils.
test_metrics = eval_on_test_set(dataset["test"], tr, tokenizer, "test")

In [None]:
# Overall test-set scores, two decimals each, framed by blank lines.
report_rows = (
    ("F1", "test/overall_f1"),
    (" P", "test/overall_precision"),
    (" R", "test/overall_recall"),
)
report_body = "\n".join(
    f"{row_label}: {test_metrics[metric_key]:.2f}" for row_label, metric_key in report_rows
)
print(f"\n{report_body}\n")

### Detailed analysis of model performance

See notebook: [03_NER_Analysis](03_NER_Analysis.ipynb)