<a href="https://colab.research.google.com/github/github-chx/experiment/blob/main/base%20/preprocess_ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import AutoTokenizer

In [None]:
# Example utterance used throughout this notebook.
sentence = "play music on visha"
# Each entity is a (word, type, start, end) tuple. start/end are character
# offsets into `sentence` with an exclusive end: sentence[14:19] == "visha".
entities = [("visha", "AppName", 14, 19)]

In [None]:
# Load the tokenizer that belongs to the "bert-base-uncased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")


In [None]:
# Tokenize the sentence and additionally request "offset_mapping": the
# (start, end) character span of every token, which is what lets the
# character-level entity spans be aligned with tokens further down.
tokens = tokenizer(sentence, return_offsets_mapping=True)
tokens

{'input_ids': [101, 2377, 2189, 2006, 25292, 3270, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 4), (5, 10), (11, 13), (14, 17), (17, 19), (0, 0)]}

In [None]:
# Map the token ids back to their sub-word strings for inspection.
words = tokenizer.convert_ids_to_tokens(tokens["input_ids"])
words

['[CLS]', 'play', 'music', 'on', 'vis', '##ha', '[SEP]']

In [45]:
def get_entities_by_token(entity_list, token_offsets):
  """Convert character-level entity spans into token-level spans.

  Args:
    entity_list: iterable of (word, type, start, end) tuples, where start/end
      are character offsets (end exclusive) into the tokenized text.
    token_offsets: sequence of (start, end) character offsets, one per token.
      Tokens whose end offset is 0 (the (0, 0) entries) are ignored.

  Returns:
    A tuple (entities_by_token, edge_mismatch):
      entities_by_token: list of (word, type, start_token, end_token) with an
        exclusive end_token, one entry per entity whose start and end both
        fall exactly on token boundaries.
      edge_mismatch: True if at least one entity could not be aligned and was
        therefore left out of entities_by_token.

  Side effects:
    Prints one line per entity with the resolved token indices; an index that
    could not be resolved is printed as None.
  """
  start2id = {}
  end2id = {}
  for i, (s, e) in enumerate(token_offsets):
    if e == 0:
      # Tokens without a character span cannot anchor an entity.
      continue
    start2id[s] = i
    end2id[e] = i + 1  # exclusive end index

  entities_by_token = []
  edge_mismatch = False
  for w, t, s, e in entity_list:
    # None (not 0) marks "no token boundary here": token index 0 is a valid
    # start when the offsets do not begin with a span-less token.
    sid = start2id.get(s)
    eid = end2id.get(e)
    if sid is not None and eid is not None and eid > sid:
      entities_by_token.append((w, t, sid, eid))
    else:
      edge_mismatch = True
    print(w, t, sid, eid)
  return entities_by_token, edge_mismatch


class TrainingData:
  """One tokenized sentence together with its token-level entity spans.

  Attributes:
    sub_words: the sub-word strings passed in by the caller.
    input_ids, token_type_ids, attention_mask, offset_mapping: the entries of
      the same name taken from `tokens`.
    length: number of entries in input_ids.
    entities_tokens: (word, type, start_token, end_token) tuples produced by
      get_entities_by_token, end_token exclusive.
    edge_mismatch: True if at least one entity could not be aligned to token
      boundaries.
    edge_match: same value as edge_mismatch. Despite its name it is True when
      a MISMATCH occurred; kept only for backward compatibility.
  """

  def __init__(self, sub_words, tokens, entities):
    """Store the tokenizer output and align `entities` to token indices.

    Args:
      sub_words: sub-word strings of the sentence.
      tokens: mapping with "input_ids", "token_type_ids", "attention_mask"
        and "offset_mapping" entries.
      entities: (word, type, start, end) tuples with character offsets.
    """
    self.sub_words = sub_words
    self.input_ids = tokens["input_ids"]
    self.length = len(self.input_ids)
    self.token_type_ids = tokens["token_type_ids"]
    self.attention_mask = tokens["attention_mask"]
    self.offset_mapping = tokens["offset_mapping"]
    self.entities_tokens, self.edge_mismatch = get_entities_by_token(
      entities, self.offset_mapping)
    # Historical, misleadingly named alias of edge_mismatch.
    self.edge_match = self.edge_mismatch

  def get_tags(self, max_len):
    """Return a list of max_len tag strings for this sentence.

    Position 0 is "START_TAG" and position length-1 is "END_TAG". The first
    token of every entity is tagged "B-<type>", its remaining tokens
    "I-<type>", and every other position (including those past the end of the
    sentence) is "O".

    If max_len is shorter than the sentence, the tags are clipped to max_len
    instead of raising IndexError: anything that would land at index
    >= max_len (END_TAG, entity tokens) is simply omitted.
    """
    tags = ["O"] * max(max_len, 0)
    limit = len(tags)
    if limit:
      tags[0] = "START_TAG"
    # Guard against both overflow and the negative index an empty sequence
    # would otherwise produce.
    if 0 < self.length <= limit:
      tags[self.length - 1] = "END_TAG"
    for _, label, start, end in self.entities_tokens:
      if start >= limit:
        continue
      tags[start] = f"B-{label}"
      for j in range(start + 1, min(end, limit)):
        tags[j] = f"I-{label}"
    return tags


# Build the training example for the demo sentence, then show every sub-word
# next to its tag. `tags` is padded to 10 entries; zip stops after the last
# sub-word, so the padding is not printed.
train_data = TrainingData(words, tokens, entities)

tags = train_data.get_tags(max_len=10)
for pair in zip(words, tags):
  print(*pair)

visha AppName 4 6
[CLS] START_TAG
play O
music O
on O
vis B-AppName
##ha I-AppName
[SEP] END_TAG
