In [1]:
import re

import spacy
from spacy import displacy
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [2]:
# Checkpoint name passed to `from_pretrained` for both the tokenizer and the
# token-classification model used in this notebook.
MODEL_NAME = "tner/roberta-large-ontonotes5"

In [3]:
# Pick the best available accelerator instead of assuming Apple Metal ("mps"),
# so the notebook also runs on CUDA machines and on plain CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Load the tokenizer and the token-classification head from the same checkpoint
# so that vocabulary and label ids agree, then move the model to DEVICE.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME).to(DEVICE)

  return self.fget.__get__(instance, owner)()


In [4]:
# The Doc produced by this pipeline is only used to hold and render character
# spans, so the listed trained components are excluded at load time.
spacy_pipeline = spacy.load(
    "en_core_web_sm",
    exclude=["tagger", "parser", "ner", "lemmatizer", "tok2vec"],
)

In [5]:
# Input passage: this string is tokenized for the model and also turned into
# the spaCy Doc that the predicted spans are attached to.
sample_text = """Existing additively manufactured aluminum alloys exhibit poor creep resistance due to coarsening of their strengthening phases and refined grain structures. In this paper, we report on a novel additively manufactured Al-10.5Ce-3.1Ni-1.2Mn wt.% alloy which displays excellent creep resistance relative to cast high-temperature aluminum alloys at 300–400°C. The creep resistance of this alloy is attributed to a high volume fraction (∼35%) of submicron intermetallic strengthening phases which are coarsening-resistant for hundreds of hours at 350°C. The results herein demonstrate that additive manufacturing provides opportunities for development of creep-resistant aluminum alloys that may be used in bulk form in the 250–400°C temperature range. Pathways for further development of such alloys are identified."""

In [6]:
# Tokenize the passage and keep, for every token, its (start, end) character
# offsets into `sample_text` so predictions can be mapped back onto the text.
tokenized_dict = tokenizer(sample_text, return_tensors="pt", return_offsets_mapping=True)

# Drop only the batch dimension explicitly; a bare `torch.squeeze` would also
# collapse any other size-1 dimension.
offset_mapping = tokenized_dict.offset_mapping[0]

# Inference only: skip autograd bookkeeping, and send the inputs to whichever
# device the model already lives on rather than hard-coding one.
with torch.no_grad():
    model_output = model(
        input_ids=tokenized_dict.input_ids.to(model.device),
        attention_mask=tokenized_dict.attention_mask.to(model.device),
    )

In [7]:
# Highest-scoring class id per token, flattened to a plain Python list.
label_idxs = model_output.logits.argmax(dim=-1).squeeze().tolist()
# Translate class ids into their string labels via the model config.
label_list = list(map(model.config.id2label.__getitem__, label_idxs))

In [8]:
# Run the spaCy pipeline over the sample text; the resulting Doc is what the
# predicted character spans are aligned to and attached to.
spacy_doc = spacy_pipeline(sample_text)

In [23]:
# Labels that never produce an annotation (the "outside" tag).
SKIP_LABELS = ["O"]


def _merge_token_labels(labels, offsets, skip_labels=SKIP_LABELS):
    """Collapse per-token labels into contiguous character-level annotations.

    Args:
        labels: one label string per token (e.g. "B-ORG", "I-ORG", "O").
        offsets: one (start_char, end_char) pair per token, parallel to `labels`.
        skip_labels: entity types that are dropped from the result.

    Returns:
        A list of (entity_type, start_char, end_char) tuples, one per maximal
        run of consecutive tokens sharing the same entity type, in text order.
    """
    merged = []
    current = None  # (entity_type, start_char, end_char) of the run being built

    for label, (offset_start, offset_end) in zip(labels, offsets):
        if offset_start == offset_end:
            # Tokens with an empty character range (the tokenizer's special
            # tokens) cover no text; letting them join a run could open a span
            # at character 0 or pull an existing span's end back to 0.
            continue

        # Strip only a leading tag prefix so the rest of the label is untouched.
        # NOTE(review): because B-/I- is discarded, two adjacent entities of
        # the same type are merged into one span — confirm this is intended.
        entity_type = re.sub(r"^[BIO]-", "", label)

        if current is not None and entity_type == current[0]:
            # Same type as the running annotation: extend its end offset.
            current = (current[0], current[1], offset_end)
            continue

        # Type changed (or first real token): emit the finished run, start anew.
        if current is not None and current[0] not in skip_labels:
            merged.append(current)
        current = (entity_type, offset_start, offset_end)

    # Flush the last run; without this an entity at the very end of the input
    # is lost unless a differently-labelled token happens to follow it.
    if current is not None and current[0] not in skip_labels:
        merged.append(current)

    return merged


annotations_list = _merge_token_labels(label_list, offset_mapping.tolist())
    

# Convert each (label, start, end) annotation into a spaCy Span. "expand"
# alignment snaps character offsets outward to token boundaries; offsets that
# cannot be aligned come back as None and are filtered out.
candidate_spans = (
    spacy_doc.char_span(begin, finish, label=entity_label, alignment_mode="expand")
    for entity_label, begin, finish in annotations_list
)
spacy_spans = [candidate for candidate in candidate_spans if candidate is not None]

# Store the spans on the Doc under the "sc" key.
spacy_doc.spans["sc"] = spacy_spans

In [24]:
# Visualize the Doc with displacy's span style, which draws entries from
# doc.spans (NOTE(review): "sc" is displacy's default spans_key per the spaCy
# docs — confirm if the key is ever changed).
displacy.render(spacy_doc, style="span")