In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForTokenClassification, DataCollatorForTokenClassification, Trainer, TrainingArguments

In [None]:
# Load the "CMeEE" configuration of the chinese_ner_sft dataset from the Hugging Face Hub.
# NOTE(review): trust_remote_code=True presumably lets the dataset repository's own
# loading script run locally — confirm the repository is trusted before running.
ds = load_dataset("qgyd2021/chinese_ner_sft", "CMeEE", trust_remote_code=True)
# Display the loaded dataset object.
ds

In [None]:
# Hold out 10 examples of the original "train" split as a small test set.
# A fixed seed keeps the split identical across kernel restarts, so results are
# comparable between runs.
# NOTE: `ds` is rebound here, so re-running this cell without re-running the load
# cell splits the already-reduced "train" split again.
ds = ds["train"].train_test_split(test_size=10, seed=42)
ds

In [None]:
# Take the first 10 training examples as a small sample for interactive inspection.
sample_dataset = ds["train"].select(range(10))
sample_dataset

In [None]:
# Show the first sample record in full.
sample_dataset[0]

In [6]:
# Keep the first sample record around; the following cells inspect its fields.
a = sample_dataset[0]

In [None]:
# Raw text of the sample record.
text = a["text"]
text

In [None]:
# Entity annotations of the sample record (a dict of parallel lists, see the next cell).
entities = a["entities"]
entities

In [9]:
# Sanity check on the sample record: every annotated (start, end) pair must slice
# exactly the annotated entity text out of the raw text.
text = a["text"]
start_idx, end_idx, entity_text = (
    entities[key] for key in ("start_idx", "end_idx", "entity_text")
)

for start, end, true_text in zip(start_idx, end_idx, entity_text):
    # end_idx is inclusive, hence the +1 on the slice's upper bound.
    label_text = text[start : end + 1]
    assert label_text == true_text

In [10]:
def getLabel(examples):
    """Build the BIO tag vocabulary from the entity annotations in ``examples``.

    Args:
        examples: Iterable of records. Each record has an ``"entities"`` mapping
            with the parallel sequences ``"entity_label"`` and ``"entity_names"``.

    Returns:
        A tuple ``(label2name, label2id, id2label)``:
          * ``label2name`` maps each raw entity label to the name seen with it.
          * ``label2id`` maps each tag (``"B-<label>"``, ``"I-<label>"``, ``"O"``)
            to an integer id.
          * ``id2label`` is the inverse of ``label2id``.

    Raises:
        AssertionError: If the same label appears with two different names.
    """
    label2name = {}
    for item in examples:
        entities = item["entities"]
        for label, name in zip(entities["entity_label"], entities["entity_names"]):
            # Register the label on first sight; afterwards it must keep the same name.
            known_name = label2name.setdefault(label, name)
            assert known_name == name, f"{known_name} != {name}"

    # Assign ids over the *sorted* labels so the id of a tag does not depend on the
    # order in which examples are visited (which changes with every shuffle/split).
    tags = [prefix + label for label in sorted(label2name) for prefix in ("B-", "I-")]
    # Non-entity tag always comes last.
    tags.append("O")

    id2label = dict(enumerate(tags))
    label2id = {tag: index for index, tag in id2label.items()}

    return label2name, label2id, id2label

# Build the tag vocabulary from the training split only.
# NOTE(review): a label that occurs only in the test split would be missing from
# label2id — confirm that cannot happen for this dataset.
label2name, label2id, id2label = getLabel(ds["train"])

In [None]:
def getModelAndTokenizer(checkpoint):
    """Load a token-classification model and its tokenizer from ``checkpoint``.

    The model is configured with the notebook-level ``label2id`` / ``id2label``
    mappings.

    Args:
        checkpoint: Model identifier or path accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    label_maps = {"label2id": label2id, "id2label": id2label}
    model = AutoModelForTokenClassification.from_pretrained(checkpoint, **label_maps)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer

# Checkpoint used as the starting point for fine-tuning.
checkpoint = "google-bert/bert-base-chinese"
# Notebook-level model and tokenizer; later cells (tokenization, training, inference) use them.
model, tokenizer = getModelAndTokenizer(checkpoint)

In [None]:
# Tokenize the first two sample texts and look at the character offsets the
# tokenizer reports for each token of the first one.
# A comprehension is used so this cell does not overwrite the notebook-level
# `text` / `item` variables.
text_list = [item["text"] for item in sample_dataset.select(range(2))]
inputs = tokenizer(text_list, return_offsets_mapping=True)
inputs["offset_mapping"][0]

In [13]:
def _char_level_tags(start_idx, end_idx, entity_label):
    """Map every character position covered by an entity to its BIO tag."""
    char_tags = {}
    for ent_start, ent_end, ent_type in zip(start_idx, end_idx, entity_label):
        # end_idx is inclusive: the entity covers positions ent_start..ent_end.
        char_tags[ent_start] = "B-" + ent_type
        for pos in range(ent_start + 1, ent_end + 1):
            char_tags[pos] = "I-" + ent_type
    # NOTE: when entity spans overlap, later entities overwrite earlier ones.
    return char_tags


def _token_label_ids(offset_mapping, char_tags):
    """Turn per-character BIO tags into one label id per token."""
    label_ids = []
    for tok_start, tok_end in offset_mapping:
        if tok_start == tok_end:
            # Tokens with an empty character span (e.g. special tokens added by the
            # tokenizer) get -100, the label id the token-classification collator
            # also uses for padding.
            label_ids.append(-100)
            continue
        # Use the first tagged character inside the token. Looking only at the
        # token's first character would mark a multi-character token as "O" when an
        # entity begins in its middle, leaving the entity without a "B-" tag.
        tag = next(
            (char_tags[pos] for pos in range(tok_start, tok_end) if pos in char_tags),
            "O",
        )
        label_ids.append(label2id[tag])
    return label_ids


def process_function(examples, max_length=384):
    """Tokenize a batch of texts and attach token-level BIO label ids.

    Intended for ``Dataset.map(..., batched=True)``. Uses the notebook-level
    ``tokenizer`` and ``label2id``.

    Args:
        examples: Batch mapping with ``"text"`` (list of strings) and
            ``"entities"`` (list of mappings holding the parallel sequences
            ``"start_idx"``, ``"end_idx"`` (inclusive) and ``"entity_label"``).
        max_length: Maximum sequence length passed to the tokenizer; longer
            inputs are truncated. Defaults to 384.

    Returns:
        The tokenizer output (including ``"offset_mapping"``) with an added
        ``"labels"`` entry: one list of label ids per example, aligned with that
        example's tokens.

    Raises:
        AssertionError: If the number of entity records differs from the number
            of texts.
        KeyError: If an entity label has no ``B-``/``I-`` entry in ``label2id``.
    """
    texts = examples["text"]
    spans = [
        (ent["start_idx"], ent["end_idx"], ent["entity_label"])
        for ent in examples["entities"]
    ]
    assert len(spans) == len(texts)

    inputs = tokenizer(
        texts,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
    )

    # First build the position -> tag mapping per example, then project it onto tokens.
    inputs["labels"] = [
        _token_label_ids(offsets, _char_level_tags(*span))
        for offsets, span in zip(inputs["offset_mapping"], spans)
    ]
    return inputs
    
        

In [None]:
# Tokenize both splits and attach the token-level labels, processing examples in batches.
tokenizer_dataset = ds.map(process_function, batched=True)
tokenizer_dataset

In [15]:
# Training configuration: checkpoints and logs go to output/CMeEE, metrics are
# logged every 20 steps, and both training and evaluation use 32 examples per device.
# NOTE(review): no evaluation strategy is set, so eval_dataset is presumably only
# used by explicit trainer.evaluate() calls — confirm against TrainingArguments defaults.
args = TrainingArguments(
    output_dir="output/CMeEE",
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    logging_steps=20,
)

# The collator assembles variable-length examples (inputs and labels) into batches.
data_collator = DataCollatorForTokenClassification(tokenizer)

trainer = Trainer(
    model=model,
    args=args,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenizer_dataset["train"],
    eval_dataset=tokenizer_dataset["test"],
)

In [None]:
# Fine-tune the model; this is the long-running cell of the notebook.
trainer.train()

In [17]:
from transformers import pipeline

In [None]:
# Wrap the fine-tuned model and tokenizer in a token-classification pipeline.
# NOTE(review): aggregation_strategy="simple" presumably merges adjacent tokens of the
# same entity into one span, and ignore_labels=["O"] drops non-entity predictions —
# confirm against the pipeline documentation.
pipe = pipeline("token-classification", model=model, tokenizer=tokenizer, aggregation_strategy="simple", ignore_labels=["O"])

In [None]:
# Pick the first held-out example and show its text next to the gold entity annotations.
# NOTE: this rebinds the notebook-level `text` and `entities` used by earlier cells.
test_data = ds["test"]
text = test_data[0]["text"]
entities = test_data[0]["entities"]
text, entities

In [None]:
# Run the pipeline on the held-out text; compare with the gold entities shown above.
pipe(text)