In [67]:
import json
from transformers import BertTokenizerFast

In [68]:
# Load the CMeEE training annotations: a JSON list of records, each holding the raw
# 'text' and its 'entities' (start_idx, end_idx, type, entity).
with open('CMeEE_train.json', mode='r', encoding='utf-8') as json_file:
    data_raw = json.load(json_file)

# NOTE(review): this dumps the whole file; for a full-size training set prefer
# len(data_raw) and data_raw[:3].
print(data_raw)  # a list of dicts

[{'text': '【病原和流行病学】狂犬病病毒（rabiesvirus）属弹状病毒科狂犬病病毒属。', 'entities': [{'start_idx': 9, 'end_idx': 13, 'type': 'mic', 'entity': '狂犬病病毒'}, {'start_idx': 15, 'end_idx': 25, 'type': 'mic', 'entity': 'rabiesvirus'}, {'start_idx': 28, 'end_idx': 31, 'type': 'mic', 'entity': '弹状病毒'}, {'start_idx': 33, 'end_idx': 37, 'type': 'mic', 'entity': '狂犬病病毒'}]}, {'text': '对儿童SARST细胞亚群的研究表明，与成人SARS相比，儿童细胞下降不明显，证明上述推测成立。', 'entities': [{'start_idx': 3, 'end_idx': 9, 'type': 'bod', 'entity': 'SARST细胞'}, {'start_idx': 19, 'end_idx': 24, 'type': 'dis', 'entity': '成人SARS'}]}, {'text': '研究证实，细胞减少与肺内病变程度及肺内炎性病变吸收程度密切相关。', 'entities': [{'start_idx': 10, 'end_idx': 10, 'type': 'bod', 'entity': '肺'}, {'start_idx': 10, 'end_idx': 13, 'type': 'sym', 'entity': '肺内病变'}, {'start_idx': 17, 'end_idx': 17, 'type': 'bod', 'entity': '肺'}, {'start_idx': 17, 'end_idx': 22, 'type': 'sym', 'entity': '肺内炎性病变'}]}]


In [69]:
# Peek at the first annotated record: its raw text plus the list of entity annotations.
first_sample = data_raw[0]
first_sample

{'text': '【病原和流行病学】狂犬病病毒（rabiesvirus）属弹状病毒科狂犬病病毒属。',
 'entities': [{'start_idx': 9,
   'end_idx': 13,
   'type': 'mic',
   'entity': '狂犬病病毒'},
  {'start_idx': 15, 'end_idx': 25, 'type': 'mic', 'entity': 'rabiesvirus'},
  {'start_idx': 28, 'end_idx': 31, 'type': 'mic', 'entity': '弹状病毒'},
  {'start_idx': 33, 'end_idx': 37, 'type': 'mic', 'entity': '狂犬病病毒'}]}

In [70]:
# Sanity-check the annotation convention on the first record, deriving every index
# from the annotations instead of hard-coding them.
# end_idx is INCLUSIVE: it is the position of the entity's last character
# (e.g. end_idx 13 -> '毒'), not one past it.
sample_text = data_raw[0]['text']
sample_entities = data_raw[0]['entities']

for k, ent in enumerate(sample_entities):
    is_last = k == len(sample_entities) - 1
    # Blank line after the last character to separate the two listings.
    print(sample_text[ent['end_idx']], end='\n\n' if is_last else '\n')

# Python slices exclude the stop position, so add 1 to include end_idx; the slice
# must reproduce the annotated entity string exactly.
for ent in sample_entities:
    span = sample_text[ent['start_idx']:ent['end_idx'] + 1]
    assert span == ent['entity'], (ent, span)
    print(span)

毒
s
毒
毒

狂犬病病毒
rabiesvirus
弹状病毒
狂犬病病毒


In [71]:
# Pretrained Chinese RoBERTa-wwm-ext checkpoint. The *Fast* tokenizer is used because
# return_offsets_mapping (needed below to align characters with tokens) is only
# implemented by Hugging Face's fast tokenizers.
PRETRAINED_NAME = 'hfl/chinese-roberta-wwm-ext'
tokenizer = BertTokenizerFast.from_pretrained(PRETRAINED_NAME)
tokenizer

PreTrainedTokenizerFast(name_or_path='hfl/chinese-roberta-wwm-ext', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [72]:
# Token view of the first record: each Chinese character becomes its own token, while
# the Latin run 'rabiesvirus' is split into several '##'-prefixed subword pieces, so
# character positions and token positions diverge after it.
sample_text = data_raw[0]['text']
tokenizer.tokenize(sample_text)

['【',
 '病',
 '原',
 '和',
 '流',
 '行',
 '病',
 '学',
 '】',
 '狂',
 '犬',
 '病',
 '病',
 '毒',
 '（',
 'ra',
 '##bi',
 '##es',
 '##vi',
 '##rus',
 '）',
 '属',
 '弹',
 '状',
 '病',
 '毒',
 '科',
 '狂',
 '犬',
 '病',
 '病',
 '毒',
 '属',
 '。']

In [73]:
# Encode the first record as a batch of one. With return_offsets_mapping=True the
# tokenizer also returns, for every token, the half-open character span
# (char_start, char_end) it covers in the original text; special tokens such as
# [CLS] and [SEP] are reported as (0, 0).
encode_kwargs = dict(
    max_length=512,
    truncation=True,  # tokens beyond max_length are cut off, losing their offsets
    padding=True,
    return_offsets_mapping=True,
)
outputs = tokenizer([data_raw[0]['text']], **encode_kwargs)
offset_mapping = outputs["offset_mapping"]
offset_mapping  # one list of (char_start, char_end) pairs per input text

[[(0, 0),
  (0, 1),
  (1, 2),
  (2, 3),
  (3, 4),
  (4, 5),
  (5, 6),
  (6, 7),
  (7, 8),
  (8, 9),
  (9, 10),
  (10, 11),
  (11, 12),
  (12, 13),
  (13, 14),
  (14, 15),
  (15, 17),
  (17, 19),
  (19, 21),
  (21, 23),
  (23, 26),
  (26, 27),
  (27, 28),
  (28, 29),
  (29, 30),
  (30, 31),
  (31, 32),
  (32, 33),
  (33, 34),
  (34, 35),
  (35, 36),
  (36, 37),
  (37, 38),
  (38, 39),
  (39, 40),
  (0, 0)]]

In [74]:
# Per-sequence lookup tables from character position to token index.
# - Offsets equal to (0, 0) belong to special tokens ('[CLS]', '[SEP]', '[PAD]', ...)
#   and are skipped.
# - Token indices count from 0 and include the special tokens.
# - start_mapping[b][c]: token in sequence b whose first character is at position c.
# - end_mapping[b][c]:   token in sequence b whose last character is at position c
#   (offset ends are exclusive, hence char_end - 1).
start_mapping = [
    {span[0]: tok_idx
     for tok_idx, span in enumerate(seq_offsets)
     if span != (0, 0)}
    for seq_offsets in offset_mapping
]
end_mapping = [
    {span[1] - 1: tok_idx
     for tok_idx, span in enumerate(seq_offsets)
     if span != (0, 0)}
    for seq_offsets in offset_mapping
]
print(start_mapping)
print(end_mapping)

[{0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 13, 13: 14, 14: 15, 15: 16, 17: 17, 19: 18, 21: 19, 23: 20, 26: 21, 27: 22, 28: 23, 29: 24, 30: 25, 31: 26, 32: 27, 33: 28, 34: 29, 35: 30, 36: 31, 37: 32, 38: 33, 39: 34}]
[{0: 1, 1: 2, 2: 3, 3: 4, 4: 5, 5: 6, 6: 7, 7: 8, 8: 9, 9: 10, 10: 11, 11: 12, 12: 13, 13: 14, 14: 15, 16: 16, 18: 17, 20: 18, 22: 19, 25: 20, 26: 21, 27: 22, 28: 23, 29: 24, 30: 25, 31: 26, 32: 27, 33: 28, 34: 29, 35: 30, 36: 31, 37: 32, 38: 33, 39: 34}]


In [75]:
# Project each annotated character span of the first record onto token indices.
# An entity can be labelled at token level only if its first character starts a token
# and its last character ends one; anything else (a boundary inside a subword, or text
# lost to truncation) is reported rather than silently dropped.
# Note: `entity_type` is used instead of `type` so the builtin `type` is not shadowed
# in the notebook's global namespace.
for ent in data_raw[0]['entities']:
    start_idx, end_idx = ent['start_idx'], ent['end_idx']
    entity_type, entity = ent['type'], ent['entity']
    print(start_idx, end_idx, entity_type, entity)
    if start_idx in start_mapping[0] and end_idx in end_mapping[0]:
        start_span = start_mapping[0][start_idx]
        end_span = end_mapping[0][end_idx]
        # The entity consists of tokens start_span..end_span inclusive
        # (0-based, counting the leading special token).
        print(start_span, end_span)
    else:
        print('  -> boundaries do not align with tokens (or were truncated); skipped')

9 13 mic 狂犬病病毒
10 14
15 25 mic rabiesvirus
16 20
28 31 mic 弹状病毒
23 26
33 37 mic 狂犬病病毒
28 32
