# BERTModel 

In [83]:
from transformers import AutoModelForTokenClassification, AutoTokenizer
import torch
from promptedgraphs.models import EntityReference
from typing import Dict, List
from collections import defaultdict
import re
import re
from promptedgraphs.vis import render_entities

# Load a pretrained token-classification (NER) checkpoint and its matching
# tokenizer from the same repository id. Judging by the name, this is a
# cased BERT-large fine-tuned on CoNLL-03 English — TODO confirm on the hub.
model_name = "dbmdz/bert-large-cased-finetuned-conll03-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name)


async def extract_entities_bert(
    text: str, labels: Dict[str, str]
) -> List[EntityReference]:
    """Run the module-level BERT token classifier over *text* and return entity spans.

    Consecutive tokens that receive the same non-``"O"`` label are merged into a
    single span. Character offsets are taken from the tokenizer's offset mapping,
    so every returned entity satisfies ``text[e.start:e.end] == e.text``. This
    holds even when a word is split into WordPiece sub-tokens or is directly
    followed by punctuation.

    Args:
        text: The input string. It is truncated by the tokenizer to the model's
            maximum length, so entities beyond that point are not reported.
        labels: Not read by this function; accepted so the signature stays
            unchanged for existing callers.

    Returns:
        One ``EntityReference`` per detected span. The ``label`` is the model's
        own ``id2label`` value (e.g. ``"I-PER"``, ``"I-LOC"``). Results are
        grouped by label, in order of each label's first appearance, and ordered
        by position within each label.
    """
    # return_offsets_mapping needs a *fast* tokenizer; AutoTokenizer returns one
    # by default for BERT checkpoints.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        return_offsets_mapping=True,
    )
    # The model's forward() does not accept offset_mapping, so remove it first.
    # Each entry is a (char_start, char_end) pair into `text`.
    offsets = inputs.pop("offset_mapping")[0].tolist()

    with torch.no_grad():
        logits = model(**inputs).logits
    # Plain Python ints so they can be used directly as id2label keys.
    predictions = torch.argmax(logits, dim=-1)[0].tolist()

    # label -> list of [start, end] character spans, in order of appearance.
    spans: Dict[str, List[List[int]]] = defaultdict(list)
    prev_label = "O"
    for (tok_start, tok_end), pred in zip(offsets, predictions):
        if tok_start == tok_end:
            # Special tokens ([CLS], [SEP], padding) map to an empty span; never
            # let them start an entity, and break any run in progress.
            label = "O"
        else:
            label = model.config.id2label[pred]

        if label != "O":  # "O" means the token is outside any entity
            if label == prev_label:
                # Same label as the previous token: extend the current span so
                # that sub-tokens and multi-word names become one entity.
                spans[label][-1][1] = tok_end
            else:
                spans[label].append([tok_start, tok_end])
        prev_label = label

    # Slice the original text instead of re-searching for it. This avoids regex
    # injection, duplicate matches, and hits on unrelated occurrences.
    return [
        EntityReference(
            start=start,
            end=end,
            text=text[start:end],
            label=label,
        )
        for label, label_spans in spans.items()
        for start, end in label_spans
    ]

Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [84]:
text = "I am in Kansas, my favorite actor is Matt Damon and I live in North Dakota"
entities = await extract_entities_bert(text, labels=labels)

In [85]:
# Visualize the extracted entity spans against the input text via
# promptedgraphs.vis.render_entities (presumably inline notebook output — confirm).
render_entities(text, entities)

In [86]:
# Display the raw list of EntityReference objects returned by extract_entities_bert.
entities

[EntityReference(start=8, end=14, label='I-LOC', text='Kansas', reason=None),
 EntityReference(start=62, end=74, label='I-LOC', text='North Dakota', reason=None),
 EntityReference(start=37, end=47, label='I-PER', text='Matt Damon', reason=None)]