In [2]:
import json
import os
import random
from typing import Dict, List, Tuple, Union

In [3]:
import matplotlib.colors as mcolors  # красиво раскрасим наши именованные сущности
from nltk.tokenize.treebank import TreebankWordDetokenizer
from spacy import displacy
import evaluate
import numpy as np
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import DataCollatorForTokenClassification, TrainingArguments, Trainer
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Hugging Face Hub id of the WIESP2022 astrophysics NER dataset.
DATASET_NAME = "adsabs/WIESP2022-NER"

In [5]:
RANDOM_SEED = 42
# Seed every RNG used in this notebook: NumPy, Python's `random` (used later for the
# entity colours) and PyTorch (CPU generator plus the current CUDA device).
np.random.seed(RANDOM_SEED)
random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed(RANDOM_SEED)

In [6]:
# Training split of the NER dataset (columns used below: 'tokens' and 'ner_tags').
trainset = load_dataset(path=DATASET_NAME, split='train')

In [7]:
# Every tag that occurs in the training data. 'O' is pinned to index 0 and the remaining
# tags follow in sorted order; the index of a tag in `label_list` is its class id.
label_set = {tag for sentence_tags in trainset['ner_tags'] for tag in sentence_tags}
label_list = ['O', *sorted(label_set - {'O'})]

In [7]:
# One tag per line (label_list always contains at least 'O').
print('\n'.join(label_list))

O
B-Archive
B-CelestialObject
B-CelestialObjectRegion
B-CelestialRegion
B-Citation
B-Collaboration
B-ComputingFacility
B-Database
B-Dataset
B-EntityOfFutureInterest
B-Event
B-Fellowship
B-Formula
B-Grant
B-Identifier
B-Instrument
B-Location
B-Mission
B-Model
B-ObservationalTechniques
B-Observatory
B-Organization
B-Person
B-Proposal
B-Software
B-Survey
B-Tag
B-Telescope
B-TextGarbage
B-URL
B-Wavelength
I-Archive
I-CelestialObject
I-CelestialObjectRegion
I-CelestialRegion
I-Citation
I-Collaboration
I-ComputingFacility
I-Database
I-Dataset
I-EntityOfFutureInterest
I-Event
I-Fellowship
I-Formula
I-Grant
I-Identifier
I-Instrument
I-Location
I-Mission
I-Model
I-ObservationalTechniques
I-Observatory
I-Organization
I-Person
I-Proposal
I-Software
I-Survey
I-Tag
I-Telescope
I-TextGarbage
I-URL
I-Wavelength


In [8]:
# Entity types without their BIO prefix ('B-Person' / 'I-Person' -> 'Person'), deduplicated and sorted.
entity_classes = sorted({tag[2:] for tag in label_list if tag != 'O'})

In [9]:
# One random light colour per entity class: each RGB channel is drawn from [0.5, 1.0),
# consuming three `random.random()` values (R, G, B) per class, in class order.
entity_colors = [
    mcolors.rgb2hex(tuple(0.5 + random.random() / 2 for _channel in range(3)))
    for _entity_class in entity_classes
]

In [10]:
def bio_to_spans(bio: List[str]) -> List[Dict[str, Union[int, str]]]:
    """Convert a BIO tag sequence into token-level entity spans.

    The result is used to render text with named entities through displacy in the
    "span" style (https://spacy.io/usage/visualizers).

    Args:
        bio: one tag per token, e.g. ``['O', 'B-Person', 'I-Person']``. The 'O' / 'B-' /
            'I-' prefixes are matched case-insensitively; the entity type after the
            prefix is kept exactly as written.

    Returns:
        ``{'start_token', 'end_token', 'label'}`` dicts in order of appearance;
        ``end_token`` is exclusive.

    An ``I-X`` tag that does not continue an open ``X`` span (it follows 'O', starts the
    sequence, or follows an entity of a different type) opens a new span, as in the
    CoNLL ``conlleval`` convention. Previously such tags were dropped or merged into the
    wrong entity, which matters for model predictions. Tags with any other prefix leave
    the current span unchanged.
    """
    spans: List[Dict[str, Union[int, str]]] = []
    ne_tag = ''
    start_pos = -1  # -1 means "no span currently open"
    for idx, val in enumerate(bio):
        tag = val.upper()
        if tag == 'O':
            closes, opens = True, False
        elif tag.startswith('B-'):
            closes, opens = True, True
        elif tag.startswith('I-'):
            # Continue only a span of the same type; otherwise treat I- like B-.
            opens = start_pos < 0 or val[2:] != ne_tag
            closes = opens
        else:
            continue  # unrecognised tag: keep the current span as it is
        if closes and start_pos >= 0:
            spans.append({'start_token': start_pos, 'end_token': idx, 'label': ne_tag})
        if opens:
            start_pos, ne_tag = idx, val[2:]
        elif closes:
            start_pos, ne_tag = -1, ''
    # A span still open at the end runs to the last token.
    if start_pos >= 0:
        spans.append({'start_token': start_pos, 'end_token': len(bio), 'label': ne_tag})
    return spans

In [11]:
# Manual displacy "span" input for the first training example: the detokenized text,
# token-level entity spans and the token list the spans index into.
sample_for_rendering = dict(
    text=TreebankWordDetokenizer().detokenize(trainset[0]['tokens']),
    spans=bio_to_spans(trainset[0]['ner_tags']),
    tokens=trainset[0]['tokens'],
)

In [12]:
# Render the sample in the "span" style; manual=True because the input is a plain dict
# rather than a spaCy Doc.
rendered = displacy.render(
    sample_for_rendering,
    style='span',
    manual=True,
    jupyter=True,
    options=dict(ents=entity_classes, colors=dict(zip(entity_classes, entity_colors))),
)

In [13]:
def bio_to_ent(tokens: List[str], bio: List[str]) -> Tuple[str, List[Dict[str, Union[int, str]]]]:
    """Convert word tokens and their BIO tags into character-level entity spans.

    The result is used to render text with named entities through displacy in the
    "ent" style (https://spacy.io/usage/visualizers).

    Args:
        tokens: the words of one text.
        bio: one BIO tag per word.

    Returns:
        ``(full_text, entities)``, where ``full_text`` is the detokenized text and each
        entity is ``{'start', 'end', 'label'}`` with character offsets into ``full_text``
        (``end`` exclusive).

    Raises:
        RuntimeError: if ``tokens`` and ``bio`` differ in length, or if a token cannot be
            found (in order) in the detokenized text.
            NOTE(review): the Treebank detokenizer may rewrite some tokens (e.g. quote
            marks), which would trigger this error.
    """
    if len(tokens) != len(bio):
        err_msg = f'Tokens do not correspond to their labels: {len(tokens)} != {len(bio)}!'
        raise RuntimeError(err_msg)
    full_text = TreebankWordDetokenizer().detokenize(tokens)
    # Locate each token left-to-right, starting right after the previous one.
    token_bounds = []
    previous_pos = 0
    for cur in tokens:
        # str.find with a start offset: unlike full_text[previous_pos:].find(cur), this does
        # not copy the remaining text for every token, which was quadratic on long texts.
        token_start = full_text.find(cur, previous_pos)
        if token_start < 0:
            err_msg = f'The token {cur} is not found in the text "{full_text}".'
            raise RuntimeError(err_msg)
        token_end = token_start + len(cur)
        token_bounds.append((token_start, token_end))
        previous_pos = token_end
    # Map each token-level span (end_token exclusive) onto character offsets.
    entity_bounds = [
        {
            'start': token_bounds[span['start_token']][0],
            'end': token_bounds[span['end_token'] - 1][1],
            'label': span['label'],
        }
        for span in bio_to_spans(bio)
    ]
    return full_text, entity_bounds

In [14]:
# Manual displacy "ent" input for the first training example: text plus character-level entities.
sample_for_rendering_2 = {
    key: value
    for key, value in zip(
        ('text', 'ents'),
        bio_to_ent(trainset[0]['tokens'], trainset[0]['ner_tags']),
    )
}

In [15]:
# Render the same example in the "ent" style, with the same per-class colours as above.
rendered_2 = displacy.render(
    sample_for_rendering_2,
    style='ent',
    manual=True,
    jupyter=True,
    options=dict(ents=entity_classes, colors=dict(zip(entity_classes, entity_colors))),
)

In [16]:
# Pre-trained multilingual encoder to fine-tune for token classification.
MODEL_NAME = "FacebookAI/xlm-roberta-base"

In [17]:
# Sub-word tokenizer matching the pre-trained model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)

In [18]:
# Inspect how the first training example (already split into words) is cut into sub-word pieces.
example = trainset[0]
tokenized_input = tokenizer(text=example['tokens'], is_split_into_words=True)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input.input_ids)
print(tokens)

Token indices sequence length is longer than the specified maximum sequence length for this model (816 > 512). Running this sequence through the model will result in indexing errors


['<s>', '▁W', 'hil', 'st', '▁a', '▁reasonable', '▁harmoni', 'c', '▁fit', '▁to', '▁the', '▁ESP', 'a', 'DO', 'n', 'S', '▁data', '▁can', '▁be', '▁achieve', 'd', '▁using', '▁this', '▁period', ',', '▁it', '▁does', '▁not', '▁produce', '▁an', '▁acceptable', '▁pha', 'sing', '▁of', '▁all', '▁available', '▁', '〈', '▁B', '▁z', '▁', '〉', '▁measure', 'ments', '.', '▁Figur', 'e', '▁1.', '▁Photo', 'metric', '▁(', '▁top', '▁)', '▁and', '▁magnetic', '▁', '〈', '▁B', '▁z', '▁', '〉', '▁(', '▁bottom', '▁)', '▁measure', 'ments', ',', '▁phase', 'd', '▁with', '▁period', 's', '▁determine', 'd', '▁from', '▁(', 'le', 'ft', '▁to', '▁right', ')', '▁K', '2', '▁photo', 'met', 'ry', ',', '▁all', '▁', '〈', '▁B', '▁z', '▁', '〉', '▁measure', 'ments', ',', '▁and', '▁all', '▁photo', 'metric', '▁measure', 'ments', '.', '▁', '〈', '▁B', '▁z', '▁', '〉', '▁measure', 'ments', '▁were', '▁obtain', 'ed', '▁from', '▁ESP', 'a', 'DO', 'n', 'S', '▁by', '▁Shu', 'lt', 'z', '▁et', '▁al', '.', '▁(', '▁2018', '▁)', '▁and', '▁photo', 'po', 

In [19]:
def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align their BIO tags with the sub-word tokens.

    Meant for ``Dataset.map(..., batched=True)``: ``examples`` holds parallel lists of
    words under 'tokens' and of string tags under 'ner_tags'.

    Only the first sub-token of every word gets that word's tag id (its index in
    ``label_list``). Later sub-tokens of the word and special tokens get -100, the value
    ``compute_metrics`` filters out.

    Returns:
        The tokenizer output with an added 'labels' column.

    NOTE(review): ``truncation=True`` cuts every example at the model's maximum length
    (512 for this checkpoint), so words past that point never reach training or
    evaluation. The first training example is 816 sub-tokens long. A sliding window
    (``return_overflowing_tokens`` + ``stride``) would keep them, but it changes the row
    count, so the ``.map`` calls would need ``remove_columns``.
    """
    tokenized_inputs = tokenizer(examples['tokens'], truncation=True, is_split_into_words=True)
    # Tag -> id lookup built once per batch, instead of a linear `label_list.index` scan per word.
    tag_to_id = {tag: idx for idx, tag in enumerate(label_list)}
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # sub-token -> source word index (None for special tokens)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)  # special token
            elif word_idx != previous_word_idx:
                label_ids.append(tag_to_id[label[word_idx]])  # first sub-token carries the word's tag
            else:
                label_ids.append(-100)  # continuation sub-token of the same word
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs['labels'] = labels
    return tokenized_inputs

In [20]:
# Tokenize the training split and attach aligned label ids (processed in batches).
tokenized_trainset = trainset.map(function=tokenize_and_align_labels, batched=True)

In [21]:
# Show every field of the first tokenized training example, keys in alphabetical order.
# The row is fetched once: the old loop re-indexed the dataset, building the whole row
# again, for every key it printed.
for field_name, field_value in sorted(tokenized_trainset[0].items()):
    print(f'{field_name}\t{field_value}')

attention_mask	[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 

In [22]:
# Validation split, used for per-epoch evaluation and best-model selection.
valset = load_dataset(path=DATASET_NAME, split='validation')

In [23]:
# Same preprocessing as for the training split.
tokenized_valset = valset.map(function=tokenize_and_align_labels, batched=True)

In [24]:
# Pads each mini-batch dynamically, including the 'labels' column.
data_collator = DataCollatorForTokenClassification(tokenizer)

In [25]:
# Entity-level precision / recall / F1 for BIO-tagged sequences.
seqeval = evaluate.load(path='seqeval')

In [26]:
def compute_metrics(p):
    """Entity-level evaluation metrics for ``Trainer``.

    ``p`` unpacks into ``(logits, label_ids)``. The logits are argmax-ed over axis 2, and
    positions whose gold id is -100 (special tokens and non-first sub-tokens) are dropped.
    The remaining ids are mapped back to tag strings and scored with seqeval.

    Returns:
        A dict with the overall 'precision', 'recall', 'f1' and 'accuracy'.
    """
    logits, gold_ids = p
    predicted_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, gold_row in zip(predicted_ids, gold_ids):
        kept_pairs = [(pred, gold) for pred, gold in zip(predicted_row, gold_row) if gold != -100]
        true_predictions.append([label_list[pred] for pred, _ in kept_pairs])
        true_labels.append([label_list[gold] for _, gold in kept_pairs])

    scores = seqeval.compute(predictions=true_predictions, references=true_labels)
    metric_names = {
        'precision': 'overall_precision',
        'recall': 'overall_recall',
        'f1': 'overall_f1',
        'accuracy': 'overall_accuracy',
    }
    return {name: scores[key] for name, key in metric_names.items()}

In [27]:
# Class id -> tag name, stored in the model config.
id2label = {class_id: tag for class_id, tag in enumerate(label_list)}

In [28]:
print(f'{id2label}')

{0: 'O', 1: 'B-Archive', 2: 'B-CelestialObject', 3: 'B-CelestialObjectRegion', 4: 'B-CelestialRegion', 5: 'B-Citation', 6: 'B-Collaboration', 7: 'B-ComputingFacility', 8: 'B-Database', 9: 'B-Dataset', 10: 'B-EntityOfFutureInterest', 11: 'B-Event', 12: 'B-Fellowship', 13: 'B-Formula', 14: 'B-Grant', 15: 'B-Identifier', 16: 'B-Instrument', 17: 'B-Location', 18: 'B-Mission', 19: 'B-Model', 20: 'B-ObservationalTechniques', 21: 'B-Observatory', 22: 'B-Organization', 23: 'B-Person', 24: 'B-Proposal', 25: 'B-Software', 26: 'B-Survey', 27: 'B-Tag', 28: 'B-Telescope', 29: 'B-TextGarbage', 30: 'B-URL', 31: 'B-Wavelength', 32: 'I-Archive', 33: 'I-CelestialObject', 34: 'I-CelestialObjectRegion', 35: 'I-CelestialRegion', 36: 'I-Citation', 37: 'I-Collaboration', 38: 'I-ComputingFacility', 39: 'I-Database', 40: 'I-Dataset', 41: 'I-EntityOfFutureInterest', 42: 'I-Event', 43: 'I-Fellowship', 44: 'I-Formula', 45: 'I-Grant', 46: 'I-Identifier', 47: 'I-Instrument', 48: 'I-Location', 49: 'I-Mission', 50: '

In [29]:
# Tag name -> class id (inverse of id2label).
label2id = {tag: class_id for class_id, tag in enumerate(label_list)}


In [30]:
print(f'{label2id}')


{'O': 0, 'B-Archive': 1, 'B-CelestialObject': 2, 'B-CelestialObjectRegion': 3, 'B-CelestialRegion': 4, 'B-Citation': 5, 'B-Collaboration': 6, 'B-ComputingFacility': 7, 'B-Database': 8, 'B-Dataset': 9, 'B-EntityOfFutureInterest': 10, 'B-Event': 11, 'B-Fellowship': 12, 'B-Formula': 13, 'B-Grant': 14, 'B-Identifier': 15, 'B-Instrument': 16, 'B-Location': 17, 'B-Mission': 18, 'B-Model': 19, 'B-ObservationalTechniques': 20, 'B-Observatory': 21, 'B-Organization': 22, 'B-Person': 23, 'B-Proposal': 24, 'B-Software': 25, 'B-Survey': 26, 'B-Tag': 27, 'B-Telescope': 28, 'B-TextGarbage': 29, 'B-URL': 30, 'B-Wavelength': 31, 'I-Archive': 32, 'I-CelestialObject': 33, 'I-CelestialObjectRegion': 34, 'I-CelestialRegion': 35, 'I-Citation': 36, 'I-Collaboration': 37, 'I-ComputingFacility': 38, 'I-Database': 39, 'I-Dataset': 40, 'I-EntityOfFutureInterest': 41, 'I-Event': 42, 'I-Fellowship': 43, 'I-Formula': 44, 'I-Grant': 45, 'I-Identifier': 46, 'I-Instrument': 47, 'I-Location': 48, 'I-Mission': 49, 'I-Mo

In [31]:
# Pre-trained encoder plus a newly initialised classification head sized to the tag set;
# the id/label mappings go into the config so checkpoints and pipelines report tag names.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    id2label=id2label,
    label2id=label2id,
    num_labels=len(label_list),
)

Some weights of XLMRobertaForTokenClassification were not initialized from the model checkpoint at FacebookAI/xlm-roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [32]:
# Absolute path of the training output directory ('astro_ner' under the current working directory).
MODEL_NAME_ON_DISK = os.path.abspath(os.path.join(os.curdir, 'astro_ner'))

In [33]:
training_args = TrainingArguments(
    # Output locations: checkpoints and TensorBoard logs.
    output_dir=MODEL_NAME_ON_DISK,
    logging_dir=os.path.join(MODEL_NAME_ON_DISK, 'logs'),
    report_to='tensorboard',  # training curves are plotted in TensorBoard
    # Optimisation.
    num_train_epochs=10,
    learning_rate=1e-4,
    # Warm-up: start from a near-zero LR and raise it linearly to 1e-4 over the first half
    # of training (i.e. until epoch 5 of 10).
    warmup_ratio=0.5,
    weight_decay=0.01,  # regularises the weight updates
    per_device_train_batch_size=4,
    per_device_eval_batch_size=8,  # no gradients are computed at evaluation, so a bigger mini-batch fits
    # Evaluate, checkpoint and log once per epoch.
    eval_strategy='epoch',
    save_strategy='epoch',
    logging_strategy='epoch',
    # To save disk space keep only 2 checkpoints: the best one and the latest one.
    save_total_limit=2,
    # Best-model selection: higher validation F1 is better; reload the best model when training ends.
    metric_for_best_model='f1',
    greater_is_better=True,
    load_best_model_at_end=True,
    seed=RANDOM_SEED,
    data_seed=RANDOM_SEED,
)

In [34]:
# Wire model, data, collator and metrics into a Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_trainset,
    eval_dataset=tokenized_valset,
    # `processing_class` replaces the deprecated `tokenizer=` argument of recent transformers
    # releases, presumably the source of the warning this cell printed. The tokenizer is
    # still saved into each checkpoint, which the pipeline cell below relies on.
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

  trainer = Trainer(


In [None]:
# NOTE(review): this cell's execution count is [None] and the log stops after epoch 6 of 10,
# so the run looks interrupted; re-run it to completion before trusting the checkpoints below.
trainer.train(resume_from_checkpoint=None)

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,1.044,0.282437,0.578831,0.62591,0.601451,0.937679
2,0.2161,0.189834,0.76801,0.702469,0.733779,0.95406
3,0.1449,0.158551,0.721534,0.771633,0.745743,0.957966
4,0.1159,0.167966,0.769022,0.763662,0.766332,0.959645
5,0.102,0.168885,0.723091,0.77664,0.748909,0.957903
6,0.0856,0.164122,0.757171,0.788809,0.772666,0.960168


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [35]:
# Order the saved checkpoints so that possible_checkpoints[0] is the best one.
# Previously the directory with the most files was assumed to be "best", which says
# nothing about validation quality. Now the best checkpoint recorded by Trainer
# (`best_model_checkpoint` in trainer_state.json) comes first, followed by the rest
# from newest to oldest. If no best checkpoint is recorded, the newest comes first.
_CHECKPOINT_PREFIX = 'checkpoint-'


def _checkpoint_step(checkpoint_dir: str) -> int:
    """Global step encoded in a Trainer checkpoint directory name ('checkpoint-<step>')."""
    return int(os.path.basename(checkpoint_dir)[len(_CHECKPOINT_PREFIX):])


def _best_checkpoint_name(checkpoint_dirs: List[str]) -> Union[str, None]:
    """Directory name of the best checkpoint recorded by Trainer, or None if unavailable.

    Reads the newest checkpoint's trainer_state.json. Only the basename is kept, because
    the stored path is the one used at training time and the folder may have moved since.
    """
    if not checkpoint_dirs:
        return None
    state_path = os.path.join(max(checkpoint_dirs, key=_checkpoint_step), 'trainer_state.json')
    if not os.path.isfile(state_path):
        return None
    with open(state_path, encoding='utf-8') as state_file:
        best_path = json.load(state_file).get('best_model_checkpoint')
    return os.path.basename(os.path.normpath(best_path)) if best_path else None


checkpoint_dirs = [
    os.path.join(MODEL_NAME_ON_DISK, name)
    for name in os.listdir(MODEL_NAME_ON_DISK)
    if name.startswith(_CHECKPOINT_PREFIX)
    and name[len(_CHECKPOINT_PREFIX):].isdigit()
    and os.path.isdir(os.path.join(MODEL_NAME_ON_DISK, name))
]
best_checkpoint_name = _best_checkpoint_name(checkpoint_dirs)
# False sorts before True, so the best checkpoint comes first; ties are broken newest-first.
possible_checkpoints = sorted(
    checkpoint_dirs,
    key=lambda path: (os.path.basename(path) != best_checkpoint_name, -_checkpoint_step(path)),
)

In [36]:
import gc

# Free memory before building the inference pipeline. First collect unreachable Python
# objects, then return PyTorch's cached but unused CUDA memory (in this order).
for release_memory in (gc.collect, torch.cuda.empty_cache):
    release_memory()

In [37]:
for checkpoint_path in possible_checkpoints:
    print(checkpoint_path)

C:\proj\nlp\lab2\astro_ner\checkpoint-2200
C:\proj\nlp\lab2\astro_ner\checkpoint-2634


In [38]:
# NER pipeline over the preferred checkpoint (first entry of possible_checkpoints).
# Use the first GPU when one is available and fall back to CPU (device=-1) otherwise;
# the previous hard-coded device=0 failed on machines without CUDA.
classifier = pipeline(
    'ner',
    model=possible_checkpoints[0],
    device=0 if torch.cuda.is_available() else -1,
)

In [39]:
# English acknowledgements paragraph used to try out the fine-tuned model
# (adjacent string literals are concatenated into one string).
original_text = (
    'The authors would like to thank Adam Burgasser, Brendan Bowler, Kelle Cruz, '
    'Mike Cushing, Michael Liu, and Emily Rice for useful discussions on benchmark '
    'systems, data treatment, and various data-model comparison approaches. '
    'The authors thank Richard Freedman and Roxana Lupu for providing gas opacities '
    'and Caroline Morley for radiative transfer code comparisons and helpful discussions. '
    'We thank Jacob Lustig-Yeager and Kyle Luther for rewriting portions of the code '
    'in python and C for significant speed improvements and also Dan Foreman-Mackey '
    'for making EMCEE available to the community. '
    'Finally, we thank the anonymous referee and statistics consultant for useful '
    'and insightful comments.'
)

In [44]:
# Word-level entities: sub-word predictions are merged per word using the first sub-token's label.
original_res = classifier(
    original_text,
    aggregation_strategy='first',
)

In [45]:
print(f'{original_res}')

[{'entity_group': 'Person', 'score': 0.9994726, 'word': 'AdamBurgasser,', 'start': 32, 'end': 47}, {'entity_group': 'Person', 'score': 0.99946135, 'word': 'BrendanBowler,', 'start': 48, 'end': 63}, {'entity_group': 'Person', 'score': 0.99942684, 'word': 'KelleCruz,', 'start': 64, 'end': 75}, {'entity_group': 'Person', 'score': 0.99946856, 'word': 'MikeCushing,', 'start': 76, 'end': 89}, {'entity_group': 'Person', 'score': 0.99952126, 'word': 'MichaelLiu,', 'start': 90, 'end': 102}, {'entity_group': 'Person', 'score': 0.99944293, 'word': 'EmilyRice', 'start': 107, 'end': 117}, {'entity_group': 'Person', 'score': 0.9994572, 'word': 'RichardFreedman', 'start': 243, 'end': 259}, {'entity_group': 'Person', 'score': 0.9994682, 'word': 'RoxanaLupu', 'start': 264, 'end': 275}, {'entity_group': 'Person', 'score': 0.99949706, 'word': 'CarolineMorley', 'start': 308, 'end': 323}, {'entity_group': 'Person', 'score': 0.99907404, 'word': 'JacobLustig-Yeager', 'start': 398, 'end': 417}, {'entity_group

In [42]:
# Russian translation of the same paragraph, used to probe cross-lingual transfer of the
# model fine-tuned on English data (adjacent literals are concatenated into one string).
ru_text = (
    'Авторы хотели бы поблагодарить Адама Бургассера, Брендана Боулера, Келли Круз, '
    'Майка Кушинга, Майкла Лю и Эмили Райс за полезные обсуждения систем бенчмарков, '
    'обработки данных и различных подходов к сравнению моделей данных. '
    'Авторы благодарят Ричарда Фридмана и Роксану Лупу за предоставление непрозрачности '
    'газа и Кэролайн Морли за сравнения кодов переноса излучения и полезные обсуждения. '
    'Мы благодарим Джейкоба Люстига-Йегера и Кайла Лютера за переписывание частей кода '
    'на Python и C для значительного улучшения скорости, а также Дэна Формана-Макки '
    'за предоставление EMCEE сообществу. '
    'Наконец, мы благодарим анонимного рецензента и консультанта по статистике '
    'за полезные и проницательные комментарии.'
)

In [43]:
# Same word-level aggregation as for the English text; one entity dict per line.
ru_res = classifier(
    ru_text,
    aggregation_strategy='first',
)
for entity in ru_res:
    print(entity)

{'entity_group': 'Person', 'score': 0.99940604, 'word': 'АдамаБургассера,', 'start': 31, 'end': 48}
{'entity_group': 'Person', 'score': 0.9994503, 'word': 'БренданаБоулера,', 'start': 49, 'end': 66}
{'entity_group': 'Person', 'score': 0.99946344, 'word': 'КеллиКруз,', 'start': 67, 'end': 78}
{'entity_group': 'Person', 'score': 0.99943066, 'word': 'МайкаКушинга,', 'start': 79, 'end': 93}
{'entity_group': 'Person', 'score': 0.9993192, 'word': 'МайклаЛю', 'start': 94, 'end': 103}
{'entity_group': 'Person', 'score': 0.99934244, 'word': 'ЭмилиРайс', 'start': 106, 'end': 116}
{'entity_group': 'Person', 'score': 0.99936646, 'word': 'РичардаФридмана', 'start': 243, 'end': 259}
{'entity_group': 'Person', 'score': 0.9994545, 'word': 'РоксануЛупу', 'start': 262, 'end': 274}
{'entity_group': 'Person', 'score': 0.9994248, 'word': 'КэролайнМорли', 'start': 315, 'end': 329}
{'entity_group': 'Person', 'score': 0.9979596, 'word': 'ДжейкобаЛюстига-Йегера', 'start': 405, 'end': 428}
{'entity_group': 'Per

