<a href="https://colab.research.google.com/github/git-grace/experiment/blob/main/base/preprocess_ner.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import AutoTokenizer

In [None]:
sentence = "i want watch movies on visha"
# One tuple per entity: (surface_text, entity_type, char_start, char_end).
# char_end is exclusive, so sentence[char_start:char_end] == surface_text.
entities = [
  ("visha", "App", 23, 28),
]

In [None]:
# Sanity check: the entity's character span slices back to its surface text ('visha').
sentence[23:28]

'visha'

In [None]:
# Pretrained tokenizer for the uncased BERT-base checkpoint (WordPiece sub-words, see '##ha' below).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="bert-base-uncased")


In [None]:
# Tokenize and also ask for each token's (char_start, char_end) span in `sentence`;
# special tokens such as [CLS]/[SEP] come back with a (0, 0) span.
tokens = tokenizer(text=sentence, return_offsets_mapping=True)
tokens

{'input_ids': [101, 1045, 2215, 3422, 5691, 2006, 25292, 3270, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1], 'offset_mapping': [(0, 0), (0, 1), (2, 6), (7, 12), (13, 19), (20, 22), (23, 26), (26, 28), (0, 0)]}

In [None]:
# Recover the sub-word strings (including special tokens) for the ids above.
words = tokenizer.convert_ids_to_tokens(ids=tokens["input_ids"])
words

['[CLS]', 'i', 'want', 'watch', 'movies', 'on', 'vis', '##ha', '[SEP]']

In [None]:
def build_entities_by_token(entity_list, token_offsets, verbose=True):
  """Map character-level entity spans onto token-index spans.

  Args:
    entity_list: iterable of ``(surface_text, entity_type, char_start, char_end)``
      tuples; ``char_end`` is exclusive.
    token_offsets: per-token ``(char_start, char_end)`` pairs, as returned by a
      tokenizer called with ``return_offsets_mapping=True``. Zero-length spans
      (e.g. the ``(0, 0)`` given to special tokens) are ignored.
    verbose: when True (the default) print the lookup tables and each
      entity's resolved token span, as this helper always did.

  Returns:
    ``(entities_by_token, edge_mismatch)``. ``entities_by_token`` lists
    ``(surface_text, entity_type, token_start, token_end)`` with ``token_end``
    exclusive. ``edge_mismatch`` is True when at least one entity's character
    boundaries did not coincide with token boundaries; such entities are
    left out of ``entities_by_token``.
  """
  start2id = {}
  end2id = {}
  for idx, (char_s, char_e) in enumerate(token_offsets):
    if char_s == char_e:
      # Zero-length span (special token): it cannot anchor an entity edge.
      continue
    # When several tokens share a start offset, the entity must begin at the
    # FIRST of them; for a shared end offset it must end after the LAST one.
    start2id.setdefault(char_s, idx)
    end2id[char_e] = idx + 1

  if verbose:
    print("start2id: ", start2id, "end2id: ", end2id)

  entities_by_token = []
  edge_mismatch = False
  for w, t, s, e in entity_list:
    # None (not 0) marks "no token boundary here": token index 0 is a valid
    # start whenever the sequence has no leading special token.
    sid = start2id.get(s)
    eid = end2id.get(e)
    if sid is not None and eid is not None and eid > sid:
      entities_by_token.append((w, t, sid, eid))
    else:
      edge_mismatch = True
    if verbose:
      print(w, t, sid, eid)
  return entities_by_token, edge_mismatch


class TrainingData:
  """One tokenized sentence together with its token-level entity spans.

  Args:
    sub_words: the sentence's sub-word strings (e.g. from
      ``tokenizer.convert_ids_to_tokens``); stored as-is.
    tokens: tokenizer output containing ``input_ids``, ``token_type_ids``,
      ``attention_mask`` and ``offset_mapping``.
    entities: ``(surface_text, entity_type, char_start, char_end)`` tuples with
      ``char_end`` exclusive.

  Attributes:
    entities_tokens: the entities re-expressed as
      ``(surface_text, entity_type, token_start, token_end)``.
    edge_mismatch: True when some entity did not line up with token
      boundaries (that entity is then absent from ``entities_tokens``).
  """

  def __init__(self, sub_words, tokens, entities):
    self.sub_words = sub_words
    self.input_ids = tokens["input_ids"]
    self.length = len(tokens["input_ids"])
    self.token_type_ids = tokens["token_type_ids"]
    self.attention_mask = tokens["attention_mask"]
    self.offset_mapping = tokens["offset_mapping"]
    self.entities_tokens, self.edge_mismatch = build_entities_by_token(
      entities, tokens["offset_mapping"])

  @property
  def edge_match(self):
    """Backward-compatible alias of ``edge_mismatch``.

    Despite the name it is True when an entity FAILED to align.
    """
    return self.edge_mismatch

  @edge_match.setter
  def edge_match(self, value):
    self.edge_mismatch = value

  def get_tags(self, max_len=None):
    """Return one tag string per position, padded with "O" to ``max_len``.

    Position 0 gets "START_TAG" and position ``self.length - 1`` (the last real
    token) gets "END_TAG". Each entity span gets "B-<type>" on its first token
    and "I-<type>" on the rest (written after START/END, so entity tags win on
    overlap). Every other position, including padding, is "O".

    Args:
      max_len: total number of tags to return; defaults to ``self.length``.

    Raises:
      ValueError: if ``max_len`` is smaller than ``self.length``.
    """
    if max_len is None:
      max_len = self.length
    if max_len < self.length:
      raise ValueError(
        f"max_len={max_len} is shorter than the sequence length {self.length}")
    tags = ["O"] * max_len
    tags[0] = "START_TAG"
    tags[self.length - 1] = "END_TAG"
    for _, ent_type, tok_start, tok_end in self.entities_tokens:
      tags[tok_start] = f"B-{ent_type}"
      for j in range(tok_start + 1, tok_end):
        tags[j] = f"I-{ent_type}"
    return tags


# Build the example for the demo sentence and show each sub-word beside its tag.
train_data = TrainingData(sub_words=words, tokens=tokens, entities=entities)
tags = train_data.get_tags(max_len=10)

# zip() stops at the shorter sequence, so the trailing padding tag is not printed.
for w, t in zip(words, tags):
  print(f"{w} {t}")

start2id:  {0: 1, 2: 2, 7: 3, 13: 4, 20: 5, 23: 6, 26: 7} end2id:  {1: 2, 6: 3, 12: 4, 19: 5, 22: 6, 26: 7, 28: 8}
visha App 6 8
[CLS] START_TAG
i O
want O
watch O
movies O
on O
vis B-App
##ha I-App
[SEP] END_TAG


In [None]:
# IPython magic (not plain Python): installs pytorch-crf, which provides the `torchcrf` module imported below.
%pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [None]:
import os
import math
import torch
import torch.nn as nn
from torchcrf import CRF
from itertools import repeat
from transformers import BertModel

In [None]:
class CRFModel(nn.Module):
  """BERT encoder -> Linear/ReLU/Dropout -> Linear emissions -> linear-chain CRF.

  Args:
    bert_dir: model name or path passed to ``BertModel.from_pretrained``.
    num_tags: number of label ids the classifier/CRF score.
    dropout_prob: dropout applied after the intermediate projection.
    **kwargs: optional ``hidden_size`` (defaults to the loaded BERT config's
      ``hidden_size``) and ``mid_linear_dims`` (default 128).
  """

  def __init__(self, bert_dir, num_tags, dropout_prob=0.1, **kwargs):
    super().__init__()

    self.bert_module = BertModel.from_pretrained(bert_dir)
    # Needed below (hidden_size, initializer_range); previously referenced
    # without ever being assigned, which raised AttributeError.
    self.bert_config = self.bert_module.config

    out_dims = kwargs.pop("hidden_size", self.bert_config.hidden_size)
    mid_linear_dims = kwargs.pop("mid_linear_dims", 128)

    self.mid_linear = nn.Sequential(
      nn.Linear(out_dims, mid_linear_dims),
      nn.ReLU(),
      nn.Dropout(dropout_prob),
    )
    self.classifier = nn.Linear(mid_linear_dims, num_tags)

    # Learnable scalar that forward() does not use; presumably kept so older
    # state_dicts still load — TODO confirm before removing.
    self.loss_weight = nn.Parameter(torch.FloatTensor(1), requires_grad=True)
    self.loss_weight.data.fill_(-0.2)

    self.crf_module = CRF(num_tags=num_tags, batch_first=True)

    self._init_weights([self.mid_linear, self.classifier],
                       initializer_range=self.bert_config.initializer_range)

  @staticmethod
  def _init_weights(blocks, initializer_range=0.02):
    """BERT-style initialisation for the freshly created (non-BERT) blocks.

    Linear/Embedding weights ~ N(0, initializer_range), Linear biases zero,
    LayerNorm weight one and bias zero. Every submodule of each block is visited.
    """
    for block in blocks:
      for module in block.modules():
        if isinstance(module, (nn.Linear, nn.Embedding)):
          nn.init.normal_(module.weight, mean=0.0, std=initializer_range)
          if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
          if module.weight is not None:
            nn.init.ones_(module.weight)
          if module.bias is not None:
            nn.init.zeros_(module.bias)

  def forward(self, token_ids, attention_masks, token_type_ids, labels=None, pseudo=None):
    """Score a batch and return either the CRF loss or the decoded tag paths.

    Inputs are assumed to be (batch, seq_len) tensors (the CRF is built with
    batch_first=True). ``pseudo`` is accepted for interface compatibility and
    is not used.

    Returns:
      ``(loss,)`` when ``labels`` is given: the negated CRF log-likelihood with
      ``reduction='mean'``. Otherwise ``(best_paths, emissions)`` where
      ``best_paths`` is the output of ``CRF.decode``.
    """
    bert_outputs = self.bert_module(input_ids=token_ids,
                                    attention_mask=attention_masks,
                                    token_type_ids=token_type_ids)

    # Standard path: per-token hidden states from BERT's last layer.
    seq_out = self.mid_linear(bert_outputs[0])
    emissions = self.classifier(seq_out)

    # torchcrf feeds the mask to torch.where, which requires a bool tensor on
    # current PyTorch; a uint8 (.byte()) mask fails there.
    crf_mask = attention_masks.bool()

    if labels is not None:
      tokens_loss = -1. * self.crf_module(emissions=emissions,
                                          tags=labels.long(),
                                          mask=crf_mask,
                                          reduction='mean')
      return (tokens_loss,)

    tokens_out = self.crf_module.decode(emissions=emissions, mask=crf_mask)
    return (tokens_out, emissions)

In [None]:
class InputExample:
  """A raw, untokenized example.

  Args:
    set_type: split name; labels are only converted for ``'train'``.
    text: the raw text of the example.
    labels: entity annotations for the example, or None.
    pseudo: optional pseudo-label marker carried through to the feature;
      ``convert_crf_example`` reads ``example.pseudo``. Defaults to None.
  """

  def __init__(self,
               set_type,
               text,
               labels=None,
               pseudo=None):
    self.set_type = set_type
    self.text = text
    self.labels = labels
    self.pseudo = pseudo

def _build_crf_label_ids(num_tokens, entities, max_seq_len, ent2id):
  """Return BIOES label ids for one example, framed for [CLS]/[SEP] and padded to max_seq_len."""
  label_ids = [0] * num_tokens

  # Each `ent` is read as: ent[0] = entity type, ent[1] = entity text (its
  # length is taken as the span length in tokens, valid because tokens are
  # one per character), ent[-1] = start index.
  # NOTE(review): an older comment showed (T1, DRUG_DOSAGE, 447, 450, 小蜜丸),
  # which does not match this indexing — confirm the upstream label format.
  for ent in entities or ():
    ent_type = ent[0]
    ent_start = ent[-1]
    ent_end = ent_start + len(ent[1]) - 1  # inclusive

    if ent_start == ent_end:
      label_ids[ent_start] = ent2id['S-' + ent_type]
    else:
      label_ids[ent_start] = ent2id['B-' + ent_type]
      label_ids[ent_end] = ent2id['E-' + ent_type]
      for i in range(ent_start + 1, ent_end):
        label_ids[i] = ent2id['I-' + ent_type]

  # Leave room for [CLS] and [SEP], then pad. CLS/SEP/PAD positions all get
  # id 0 (presumably ent2id['O'] == 0 — confirm against ent2id).
  label_ids = [0] + label_ids[:max_seq_len - 2] + [0]
  label_ids += [0] * (max_seq_len - len(label_ids))

  assert len(label_ids) == max_seq_len, f'{len(label_ids)}'
  return label_ids


def convert_crf_example(ex_idx, example: InputExample, tokenizer, max_seq_len, ent2id):
  """Turn one InputExample into a CRFFeature plus callback info.

  Args:
    ex_idx: index of the example (currently unused).
    example: example providing ``set_type``, ``text``, ``labels`` and,
      optionally, ``pseudo``.
    tokenizer: Hugging Face tokenizer used for ``encode_plus``.
    max_seq_len: total encoded length, including [CLS]/[SEP], after
      truncation/padding.
    ent2id: mapping from tag strings such as ``'B-<type>'`` to label ids.

  Returns:
    ``(feature, callback_info)`` where ``callback_info`` is ``(raw_text,)``.
    ``feature.labels`` is None unless ``set_type == 'train'``.
  """
  set_type = example.set_type
  raw_text = example.text
  entities = example.labels
  # Tolerate example objects that do not define `pseudo`.
  pseudo = getattr(example, "pseudo", None)

  callback_info = (raw_text,)

  tokens = fine_grade_tokenize(raw_text, tokenizer)
  # Labels index tokens by character position, so this must be 1:1.
  assert len(tokens) == len(raw_text)

  label_ids = None
  if set_type == 'train':
    label_ids = _build_crf_label_ids(len(tokens), entities, max_seq_len, ent2id)

  # `is_split_into_words` / `padding` / `truncation` replace the removed
  # `is_pretokenized` / `pad_to_max_length` arguments.
  encode_dict = tokenizer.encode_plus(text=tokens,
                                      max_length=max_seq_len,
                                      padding='max_length',
                                      truncation=True,
                                      is_split_into_words=True,
                                      return_token_type_ids=True,
                                      return_attention_mask=True)

  feature = CRFFeature(
    # bert inputs
    token_ids=encode_dict['input_ids'],
    attention_masks=encode_dict['attention_mask'],
    token_type_ids=encode_dict['token_type_ids'],
    labels=label_ids,
    pseudo=pseudo,
  )

  return feature, callback_info