In [50]:
from dataclasses import dataclass
import re

import spacy
from spacy import displacy
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification

In [2]:
# Checkpoint identifier handed to both `from_pretrained` loaders (tokenizer and
# token-classification model) so the two always come from the same checkpoint.
MODEL_NAME = "tner/roberta-large-ontonotes5"

In [3]:
# Choose the accelerator at runtime instead of assuming Apple-silicon MPS, so the
# notebook also runs on CUDA or CPU-only machines. MPS is tried first to keep the
# original behavior wherever it is available.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Tokenizer and model are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME).to(DEVICE)

  return self.fget.__get__(instance, owner)()


In [4]:
# Load the small English pipeline with the listed components excluded.
# NOTE(review): presumably only tokenization is needed here, since the docs are
# later used just for `char_span` alignment and rendering - confirm.
spacy_pipeline = spacy.load(
    "en_core_web_sm",
    exclude=["tagger", "parser", "ner", "lemmatizer", "tok2vec"],
)

In [26]:
# Five example sentences, one string per sentence, used as input for the tagging
# and visualisation cells.
sample_texts = [
    "Existing additively manufactured aluminum alloys exhibit poor creep resistance due to coarsening of their strengthening phases and refined grain structures.",
    "In this paper, we report on a novel additively manufactured Al-10.5Ce-3.1Ni-1.2Mn wt.% alloy which displays excellent creep resistance relative to cast high-temperature aluminum alloys at 300–400°C.",
    "The creep resistance of this alloy is attributed to a high volume fraction (∼35%) of submicron intermetallic strengthening phases which are coarsening-resistant for hundreds of hours at 350°C.",
    "The results herein demonstrate that additive manufacturing provides opportunities for development of creep-resistant aluminum alloys that may be used in bulk form in the 250–400°C temperature range.",
    "Pathways for further development of such alloys are identified.",
]

In [65]:
@dataclass
class EntityCharSpan:
    """An entity mention identified by its type and character offsets."""

    # Entity type with the leading BIO prefix stripped (e.g. "ORG" for "B-ORG").
    e_type: str
    # Start offset, taken from the first token merged into the span.
    start_char: int
    # End offset, taken from the last token merged into the span.
    # NOTE(review): presumably exclusive, like tokenizer offset pairs - confirm.
    end_char: int


def get_char_spans_from_labels(label_list: list[str], offset_mapping: list[list[int]], SKIP_LABELS=("O",)) -> list[EntityCharSpan]:
    """Merge per-token labels into character-level entity spans.

    Consecutive tokens whose labels share the same type (after removing a
    leading ``B-``/``I-``/``O-`` prefix) are merged into a single span running
    from the first token's start offset to the last token's end offset.

    Args:
        label_list: One label per token, e.g. ``"O"``, ``"B-ORG"``, ``"I-ORG"``.
        offset_mapping: One ``(start, end)`` character offset pair per token,
            aligned with ``label_list``. If the two differ in length, the extra
            entries of the longer one are ignored.
        SKIP_LABELS: Entity types that are never returned. A bare string is
            treated as a single label rather than as a sequence of characters.

    Returns:
        The merged spans in order of appearance, excluding skipped types.
    """
    # `("O")` is just the string "O"; normalise so `in` means membership, not
    # substring matching.
    skip_labels = (SKIP_LABELS,) if isinstance(SKIP_LABELS, str) else tuple(SKIP_LABELS)

    annotations_list = []
    current_annotation = None

    for label, (offset_start, offset_end) in zip(label_list, offset_mapping):
        # Anchor the prefix pattern so only a leading tag is removed and the
        # entity type itself is never altered.
        cleaned_label = re.sub(r"^[BIO]-", "", label)
        # Zero-width tokens cover no text (presumably special tokens, which
        # report (0, 0) - confirm against the tokenizer). They may end a span
        # but must never start or extend one.
        is_zero_width = offset_start == offset_end

        if current_annotation is not None and cleaned_label == current_annotation.e_type:
            if not is_zero_width:
                current_annotation.end_char = offset_end
            continue

        # The type changed: close the open span, if any.
        if current_annotation is not None and current_annotation.e_type not in skip_labels:
            annotations_list.append(current_annotation)

        if is_zero_width:
            current_annotation = None
        else:
            current_annotation = EntityCharSpan(e_type=cleaned_label, start_char=offset_start, end_char=offset_end)

    # Flush the span still open when the tokens run out; otherwise an entity
    # ending on the last token would be dropped.
    if current_annotation is not None and current_annotation.e_type not in skip_labels:
        annotations_list.append(current_annotation)

    return annotations_list
    

def tag_entities(sentences: list[str], tokenizer, model):
    """Run token classification over sentences and return entity character spans.

    Args:
        sentences: Texts to tag. An empty list yields an empty result without
            calling the tokenizer or the model.
        tokenizer: Callable tokenizer that returns ``input_ids``,
            ``attention_mask`` and ``offset_mapping`` tensors for the batch.
        model: Token-classification model exposing ``device`` and
            ``config.id2label``, whose output has a ``logits`` attribute.

    Returns:
        One list per input sentence, each holding the spans produced by
        ``get_char_spans_from_labels`` for that sentence.
    """
    if not sentences:
        return []

    tokenized = tokenizer(sentences, return_tensors="pt", padding=True, return_offsets_mapping=True, return_attention_mask=True)
    offset_mapping = tokenized.offset_mapping.tolist()
    attention_masks = tokenized.attention_mask.tolist()

    # Send inputs to wherever the model lives rather than assuming "mps".
    device = model.device
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        model_output = model(input_ids=tokenized.input_ids.to(device), attention_mask=tokenized.attention_mask.to(device))
    label_idxs = torch.argmax(model_output.logits, dim=-1).tolist()

    entity_char_spans = []
    for row_label_idxs, row_offsets, row_mask in zip(label_idxs, offset_mapping, attention_masks):
        # Drop padded positions from labels AND offsets together so both stay
        # aligned regardless of the tokenizer's padding side.
        kept = [
            (model.config.id2label[idx], offsets)
            for idx, offsets, attention_value in zip(row_label_idxs, row_offsets, row_mask)
            if attention_value == 1
        ]
        row_labels = [label for label, _ in kept]
        kept_offsets = [offsets for _, offsets in kept]
        entity_char_spans.append(get_char_spans_from_labels(row_labels, kept_offsets))
    return entity_char_spans

# Tag only the first two sample sentences; `all_annotations[i]` holds the entity
# spans returned for `sample_texts[i]`, so only indices 0 and 1 are available.
all_annotations = tag_entities(sample_texts[:2], tokenizer, model)

In [66]:
# One spaCy document per sample sentence, in the same order as `sample_texts`.
spacy_docs = list(map(spacy_pipeline, sample_texts))

In [70]:
# Convert the character spans of the second sentence into spaCy spans and
# attach them to its document under the "sc" span group.
target_doc = spacy_docs[1]
candidate_spans = (
    target_doc.char_span(
        entity.start_char,
        entity.end_char,
        label=entity.e_type,
        alignment_mode="expand",
    )
    for entity in all_annotations[1]
)
# `char_span` can return None; those entries are left out.
spacy_spans = [span for span in candidate_spans if span is not None]

target_doc.spans["sc"] = spacy_spans

In [71]:
# Render the second document with displaCy's span-style visualiser.
# NOTE(review): presumably this displays the "sc" span group assigned to this
# document earlier - confirm displaCy's default spans key.
displacy.render(spacy_docs[1], style="span")