In [1]:
import emoji
import argparse
import torch
import yaml

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from opencc import OpenCC
from pandarallel import pandarallel
from tqdm.auto import tqdm
from torch import cuda
from transformers import BertTokenizer, BertForTokenClassification

from langchain.text_splitter import CharacterTextSplitter

In [2]:
df_data = pd.read_excel("../data/打标结果expand.xlsx")
df_data["docid"] = df_data.apply(lambda x: x["doc_id"].split("_")[0], axis=1)
df_data = df_data[["docid", "content", 'human_brand', 'human_product']]
df_data.rename(columns={"content": "context"}, inplace=True)
df_data['doc_id'] = range(len(df_data))
df_data.head(2)

Unnamed: 0,docid,context,human_brand,human_product,doc_id
0,2024022321789920833,"霸王茶姬减脂红黑榜。CHAGEE霸王茶姬杀疯了!10元一杯的伯牙绝弦又来了。。正在减脂,但是...",['霸王茶姬'],"['伯牙绝弦', '桂馥兰香', '白雾红尘', '青青糯山', '去云南玫瑰普洱', '万...",0
1,2024022121789570136,独立小茶饼yyds❗️一次一片每天新口味。安利一款小茶礼!![哇R][哇R][哇R] 首先包...,['喜茶'],"['桂花红茶', '茉莉红茶', '百合红茶', '冰岛红茶', '柠檬红茶', '玫瑰红茶...",1


In [3]:
df_data['doc_id'].is_unique

True

In [4]:
human_product = df_data['human_product'].to_list()
human_product = [j for i in human_product if isinstance(i, str) for j in eval(i)]
with open("../dict/human_product.txt", "w") as f:
    f.write("\n".join(human_product))

In [5]:
len(human_product)

2264

In [6]:
def clean_data(example):
    text = emoji.replace_emoji(example, replace="")
    text = OpenCC('t2s.json').convert(text)
    text = ''.join([x for x in text if x.isprintable()])
    text = text.lower()
    return text

pandarallel.initialize(nb_workers=32, progress_bar=True, use_memory_fs=False)
df_data["context_cleaned"] = df_data.parallel_apply(lambda x: clean_data(x["context"]), axis=1)

INFO: Pandarallel will run on 32 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.


VBox(children=(HBox(children=(IntProgress(value=0, description='0.00%', max=10), Label(value='0 / 10'))), HBox…

In [7]:
df_data = df_data[df_data['context_cleaned'].str.len() <= 509]
df_data.reset_index(drop=True, inplace=True)

In [8]:
# def split_text(df, separator = '\n', chunksize = 500):
#     text_splitter = CharacterTextSplitter(
#         separator = separator,
#         chunk_size = chunksize,
#         chunk_overlap  = 0,
#         length_function = len
#     )
    
#     for idx, row in tqdm(df.iterrows(), desc='splitting'):
#         docid = str(row["docid"])
#         content = str(row["context_cleaned"])
#         print(len(content))
#         if content.strip(" \n") != "":
#             splitted_text = text_splitter.split_text(str(content))
#             for i, text in enumerate(splitted_text):
#                 docid = f"{docid}_{i}"
#                 df_split.loc[len(df_split)]= [docid, text]
#         else:
#             continue

In [9]:
# df_split["context_cleaned"].str.len().hist()

## utils

In [10]:
def remove_emoji(text):
    return emoji.replace_emoji(text, replace="")


def t2s(text):
    return OpenCC('t2s.json').convert(text)


def remove_inprintable(text):
    return ''.join([x for x in text if x.isprintable()])

In [11]:
def add_keywords_to_text(text, keyword_list):
    matched_keywords = []
    for keyword in keyword_list:
        if keyword in text:
            matched_keywords.append(keyword)

    labels = list(set(matched_keywords))
    labels_ = []
    for i in labels:
        temp = labels.copy()
        temp.remove(i)
        flag = 1
        for j in temp:
            if i not in j:
                continue
            else:
                flag = 0
                break
        if flag:
            labels_.append(i)
    
    text_list = []
    separator = "。！？#?!"
    start, end = 0, 0
    while end < len(text):
        if text[end] in separator:
            text_list.append(text[start:end+1])
            start, end = end + 1, end + 1
        else:
            end += 1
        
        if end == len(text) - 1:
            text_list.append(text[start:end+1])

    for idx, t in enumerate(text_list):
        temp = []
        for l in labels_:
            if l in t:
                temp.append(l)
        if temp != []:
            text_list[idx] = text_list[idx] + '||可能的实体：' + ",".join(temp) + '||' 
    return "".join(text_list)


In [12]:
def tokenize_single_text_and_add_masks(text, tokenizer):
    # tokenizer = BertTokenizer.from_pretrained("/ailab/src/models/chinese-roberta-wwm-ext_tokenizer")
    tokenized_text, tokenized_content = [], []
    headline_content = text
    for char in headline_content:
        if tokenizer.tokenize(char):
            tokenized_text.append(char)
            tokenized_content.extend(tokenizer.tokenize(char))
    
    tokenized_content = [tokenizer.cls_token] + tokenized_content + [tokenizer.sep_token]
    max_len = 512
    if len(tokenized_content) > max_len: 
        tokenized_content = tokenized_content[:max_len]
    else:
        tokenized_content = tokenized_content + [tokenizer.pad_token] * (max_len - len(tokenized_content))
    
    attn_mask = [1 if tok != tokenizer.pad_token else 0 for tok in tokenized_content]
    ids = tokenizer.convert_tokens_to_ids(tokenized_content)

    return tokenized_text, ids, attn_mask


In [13]:
def generate_single_input(text, keyword_list, tokenizer, add_keywords):
    """generate input given the text
    Args:
        arg: NameSpace, argumentation
        text: str, input text
    
    Returns:
        tokenized_content, ids, attn_mask: List
    """
    text = remove_emoji(text)
    text = t2s(text)
    text = remove_inprintable(text)

    if add_keywords:
        text = add_keywords_to_text(text, keyword_list)
    print("input text:", text, "\n")

    tokenized_text, ids, attn_mask = tokenize_single_text_and_add_masks(text, tokenizer)
    ids, attn_mask = torch.tensor([ids], dtype=torch.long), torch.tensor([attn_mask], dtype=torch.long)

    return tokenized_text, ids, attn_mask


In [14]:
def convert_labels_to_entity(entity_type_list, tokenized_text, label_pred):
    output = {label: [] for label in entity_type_list}
    tokenized_words = tokenized_text
    label_pred = label_pred
    start_idx, end_idx = -1, -1
    for idx, label in enumerate(label_pred):
        if label == 'O':
            if start_idx != -1:
                end_idx = idx
                # print(idx, entity_type, tokenized_words[start_idx:end_idx])
                output[entity_type].append("".join(tokenized_words[start_idx:end_idx]))
                start_idx, end_idx = -1, -1
            continue
        elif "B" in label:
            if start_idx != -1:
                end_idx = idx
                # print(idx, entity_type, tokenized_words[start_idx:end_idx])
                output[entity_type].append("".join(tokenized_words[start_idx:end_idx]))
                start_idx = idx
                entity_type = label.strip("B-")
            else:
                start_idx = idx
                entity_type = label.strip("B-")
        elif "I" in label:   
            if start_idx != -1:
                if entity_type != label.strip("I-"):
                    end_idx = idx
                    # print(idx, entity_type, tokenized_words[start_idx:end_idx])
                    output[entity_type].append("".join(tokenized_words[start_idx:end_idx]))
                    start_idx, end_idx = -1, -1
                else:
                    continue
            else:
                continue
    if start_idx != -1:
        end_idx = idx + 1
        output[entity_type].append("".join(tokenized_words[start_idx:end_idx]))
    output = {k: list(set([i for i in v if len(i)>1])) for (k, v) in output.items()}
    return output


In [15]:
def find_entity_index(text, entity):
    indices = []
    start = 0
    while True:
        start_index = text.find(entity, start)
        if start_index == -1:
            break
        end_index = start_index + len(entity)
        
        indices.append((start_index, end_index))
        start = end_index
    return indices

def get_entity_index(text, pred_entity):
    details = []
    for entity_type, entity_list in pred_entity.items():
        for entity in entity_list:
            indices = find_entity_index(text, entity)
            details.append(
                {
                    "entity_type": entity_type, 
                    "entity": entity, 
                    "index": indices
                })
    return details


## predict

In [16]:
keyword_path = "../dict/keyword_list_0314.txt"

In [17]:
class BERT_NER():
    def __init__(self, args):
        ## setting
        self.entity_types = ["brand", "product"]
        self.labels_to_ids = {'O': 0, 'B-brand': 1, 'I-brand': 2, 'B-product': 3, 'I-product': 4}
        self.ids_to_labels = {0: 'O', 1: 'B-brand', 2: 'I-brand', 3: 'B-product', 4: 'I-product'}
        self.device = 'cuda' if cuda.is_available() else 'cpu'
        self.num_labels = args.num_labels
        self.max_len = args.max_len

        ## keyword list
        with open(keyword_path, "r") as f:
            self.keyword_list = f.read().split("\n")
        
        ## tokenizer
        self.tokenizer = BertTokenizer.from_pretrained("hfl/chinese-roberta-wwm-ext")
        
        ## model
        self.model_path = f"{args.prefix_path}/model.pth"
        self.model = BertForTokenClassification.from_pretrained("hfl/chinese-roberta-wwm-ext", num_labels=self.num_labels)
        self.model.load_state_dict(torch.load(self.model_path))
        self.model.to(self.device)
    

    def predict_single(self, input_data):
        # print(input_data)
        docid = input_data['docid']
        text = input_data['content']
        add_keywords = input_data.get("add_keywords", True)
        output_dict = {"docid": docid, "entity":{}}

        tokenized_text, ids, masks = generate_single_input(text, self.keyword_list, self.tokenizer, add_keywords)
        ids, masks = ids.to(self.device), masks.to(self.device)

        output = self.model(input_ids=ids, attention_mask=masks, return_dict=True)
        output = output['logits']

        pred_labels = torch.argmax(output, dim=2).cpu().tolist() # (batchsize, seq_len)
        pred_labels = [self.ids_to_labels[j] for i in pred_labels for j in i] # list

        tokens = [self.tokenizer.convert_ids_to_tokens(i) for i in ids.squeeze().cpu().tolist()]

        p_l_temp = []
        for pair in zip(tokens, pred_labels):
            if pair[0] in ['[CLS]', '[SEP]', '[PAD]']:
                continue
            else:
                p_l_temp.append(pair[1])

        pred_entity = convert_labels_to_entity(self.entity_types, tokenized_text, p_l_temp)
        output_dict["entity"].update(pred_entity)

        ### add index
        details = get_entity_index(text, pred_entity)
        output_dict.update({"details": details})

        return output_dict

    # ========================== utils ==========================
    def to_numpy(self, tensor):
        return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

In [18]:
with open("../config.yaml", 'r') as config_file:
    config = yaml.safe_load(config_file)

In [19]:
parser = argparse.ArgumentParser(description='') 

parser.set_defaults(**config)
args = parser.parse_args(args = [])

parser.add_argument("--prefix_path", type=str, default=f"../experiments/{args.date}_{args.experiment_name}_{args.mode}")
args = parser.parse_args(args = [])

args

Namespace(prefix_path='../experiments/0315_bert_p=0.5_pos=sentence_BP', date='0315', experiment_name='bert_p=0.5_pos=sentence', mode='BP', entity_type=['brand', 'product'], num_labels=5, threshold=0.5, train_data='/train_dataset', valid_data='/valid_dataset', test_data='/test_dataset', pretrained_model='hfl/chinese-roberta-wwm-ext', do_training=True, max_len=512, batch_size=16, n_epochs=10, lr='1e-5', do_prediction=True)

In [20]:
input_list = [
    {
        "docid": row["doc_id"],
        "content": row["context_cleaned"],
        "add_keywords": True
    }
    for idx, row in tqdm(df_data.iterrows())
]

0it [00:00, ?it/s]

In [21]:
text = """记录一下喜茶每杯单品的最高热量和最低热量。减脂期间发现喜茶的热量计算器可以测算不同糖度以及冰/热带来的热量差异了,汇总一下(一切以小xc官方为准),单位kcal 🌟牛乳茶系列: 春光(四季春茶底):最低140,最高270 小奶茉(绿妍茶底):最低165,最高305 繁花(白兰蒙顶绿茶底):未显示 龙跃红(红袍乌龙茶底):最低155,最高325 轻芝沫红跃龙:最低265,最高365 水仙(玲珑水仙茶底):最低170,最高324 天青雨(天青雨乌龙茶底):最低160,最高280 绯红厚乳茶(绯红红茶底):最低205,最高340 老丛厚乳茶(老丛茶底):最低235,最高350 水牛乳焙茶(焙茶底):最低185,最高360 """

In [22]:
model = BERT_NER(args)

input_ = {
    "docid": "123456",
    "content": text,
    "add_keywords": True
}

model.predict_single(input_)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


input text: 记录一下喜茶每杯单品的最高热量和最低热量。||可能的实体：喜茶||减脂期间发现喜茶的热量计算器可以测算不同糖度以及冰/热带来的热量差异了,汇总一下(一切以小xc官方为准),单位kcal 牛乳茶系列: 春光(四季春茶底):最低140,最高270 小奶茉(绿妍茶底):最低165,最高305 繁花(白兰蒙顶绿茶底):未显示 龙跃红(红袍乌龙茶底):最低155,最高325 轻芝沫红跃龙:最低265,最高365 水仙(玲珑水仙茶底):最低170,最高324 天青雨(天青雨乌龙茶底):最低160,最高280 绯红厚乳茶(绯红红茶底):最低205,最高340 老丛厚乳茶(老丛茶底):最低235,最高350 水牛乳焙茶(焙茶底):最低185,最高360 ||可能的实体：老丛厚乳茶,天青雨,龙跃红,绯红厚乳茶,牛乳茶,喜茶,水仙茶,绯红红茶,老丛茶,轻芝沫红跃龙,小奶茉,乳茶系列,水牛乳焙茶,绿妍茶底,乌龙茶,四季春茶|| 



{'docid': '123456',
 'entity': {'brand': ['喜茶'],
  'product': ['乌龙茶',
   '水仙茶',
   '轻芝沫红跃龙',
   '小奶茉',
   '乳茶系列',
   '老丛茶',
   '天青雨',
   '老丛厚乳茶',
   '龙跃红',
   '绯红厚乳茶',
   '牛乳茶',
   '绯红红茶',
   '四季春茶',
   '水牛乳焙茶',
   '绿妍茶底']},
 'details': [{'entity_type': 'brand',
   'entity': '喜茶',
   'index': [(4, 6), (27, 29)]},
  {'entity_type': 'product',
   'entity': '乌龙茶',
   'index': [(155, 158), (222, 225)]},
  {'entity_type': 'product', 'entity': '水仙茶', 'index': [(197, 200)]},
  {'entity_type': 'product', 'entity': '轻芝沫红跃龙', 'index': [(173, 179)]},
  {'entity_type': 'product', 'entity': '小奶茉', 'index': [(111, 114)]},
  {'entity_type': 'product', 'entity': '乳茶系列', 'index': [(83, 87)]},
  {'entity_type': 'product', 'entity': '老丛茶', 'index': [(271, 274)]},
  {'entity_type': 'product',
   'entity': '天青雨',
   'index': [(215, 218), (219, 222)]},
  {'entity_type': 'product', 'entity': '老丛厚乳茶', 'index': [(265, 270)]},
  {'entity_type': 'product', 'entity': '龙跃红', 'index': [(149, 152)]},
  {'entity_type

In [23]:
model = BERT_NER(args)
result_dict = {}

for input_ in tqdm(input_list):
    result = model.predict_single(input_)
    result_dict.update({result["docid"]: result["entity"]})

Some weights of BertForTokenClassification were not initialized from the model checkpoint at hfl/chinese-roberta-wwm-ext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/284 [00:00<?, ?it/s]

input text: 霸王茶姬减脂红黑榜。||可能的实体：霸王茶姬||chagee霸王茶姬杀疯了!||可能的实体：chagee,霸王茶姬||10元一杯的伯牙绝弦又来了。||可能的实体：伯牙绝弦||。正在减脂,但是又想喝奶茶怎么办。霸王茶姬减脂红黑榜可以喝少喝霸王茶姬爱好者吐血整理减脂期霸王茶姬红黑榜可以喝少喝附热量亲测巨好喝爱喝奶茶的姐妹们收藏喔~可以喝伯牙绝弦桂馥兰香白雾红尘青青糯山去云南玫瑰普洱万山红金丝小种青沫观音山野栀子云中绿观音韵折桂令琥珀光七里香千峰翠海盐电解柠檬茶少喝桂子飘飘千山雪金丝小种杨枝甘露浮生梦媞云海芒芒奥利奥柑普茶冰川焦糖大红袍茶冰川恋恋山茶||可能的实体：桂子飘飘,千山雪金丝小种,折桂令,琥珀光,千峰翠,奥利奥柑普茶冰川,焦糖大红袍茶冰川,云中绿,霸王茶姬,桂馥兰香,云海芒芒,海盐电解柠檬茶,杨枝甘露,伯牙绝弦,白雾红尘,恋恋山茶,青沫观音,去云南玫瑰普洱,万山红金丝小种,观音韵,浮生梦媞,青青糯山,七里香,山野栀子|| 

input text: 独立小茶饼yyds一次一片每天新口味。安利一款小茶礼!![哇r][哇r][哇r] 首先包装好好看里面有30种不同的口味[色色r][色色r][色色r] 桂花红茶、茉莉红茶、百合红茶、冰岛红茶、柠檬红茶、玫瑰红茶、柠檬菊红、滇红龙珠白茶、玫瑰白茶、桂花白茶、薄荷白茶、荷叶白茶茉莉白茶、陈皮白茶、月光白茶、老白茶、糯香生茶、班章生茶、茉莉生茶、古树生茶冰岛生茶、黄金叶生茶、陈皮熟普、咖啡普洱、枣香熟普、陈年熟普糯香熟 牡丹熟普、菊花熟普、小青柑等 口味也太多了吧[喝奶茶r][喝奶茶r][喝奶茶r]不爱喝白开水的我可太喜欢啦,才50+平均一片不到2实在太值了叭 #||可能的实体：咖啡普洱,牡丹熟普,枣香熟普,柠檬红茶,冰岛红茶,白开水,老白茶,黄金叶生茶,古树生茶,荷叶白茶,茉莉生茶,薄荷白茶,桂花红茶,班章生茶,玫瑰红茶,小青柑,陈皮白茶,玫瑰白茶,糯香生茶,桂花白茶,茉莉红茶,冰岛生茶,滇红龙珠白茶,百合红茶,陈年熟普,月光白茶,柠檬菊红,菊花熟普,茉莉白茶,陈皮熟普||茶样试喝#普洱茶#||可能的实体：普洱茶||花茶#礼物#礼物推荐#送礼#过年送礼#伴手礼#茶样推荐#喜茶||可能的实体：喜茶|| 

input text: 五棵松水牛喝水实况。1. 煮叶 茉莉花茶冰沙/冰摇荔枝乌龙茶 喜欢喝茶的

In [24]:
df_data[["pred_brand", "pred_product"]] = df_data["doc_id"].apply(lambda x: [result_dict[x].get("brand"), result_dict[x].get("product")]).to_list()

  return asarray(a).ndim


In [25]:
df_data

Unnamed: 0,docid,context,human_brand,human_product,doc_id,context_cleaned,pred_brand,pred_product
0,2024022321789920833,"霸王茶姬减脂红黑榜。CHAGEE霸王茶姬杀疯了!10元一杯的伯牙绝弦又来了。。正在减脂,但是...",['霸王茶姬'],"['伯牙绝弦', '桂馥兰香', '白雾红尘', '青青糯山', '去云南玫瑰普洱', '万...",0,"霸王茶姬减脂红黑榜。chagee霸王茶姬杀疯了!10元一杯的伯牙绝弦又来了。。正在减脂,但是...","[chagee, 霸王茶姬]","[桂子飘飘, 千山雪金丝小种, 折桂令, 琥珀光, 千峰翠, 奥利奥柑普茶冰川, 焦糖大红袍..."
1,2024022121789570136,独立小茶饼yyds❗️一次一片每天新口味。安利一款小茶礼!![哇R][哇R][哇R] 首先包...,['喜茶'],"['桂花红茶', '茉莉红茶', '百合红茶', '冰岛红茶', '柠檬红茶', '玫瑰红茶...",1,独立小茶饼yyds一次一片每天新口味。安利一款小茶礼!![哇r][哇r][哇r] 首先包装好...,[喜茶],"[咖啡普洱, 牡丹熟普, 枣香熟普, 冰岛红茶, 柠檬红茶, 老白茶, 白开水, 黄金叶生茶..."
2,202402272963901,五棵松水牛喝水实况。1. 煮叶 茉莉花茶冰沙/冰摇荔枝乌龙茶 喜欢喝茶的uu请一定不要错过煮...,"['乐乐茶', '一点点', '茉酸奶', '霸王茶姬', 'manner', '茶百道',...","['茉莉花茶冰沙', '冰摇荔枝乌龙茶', '苹果桃子酪', '大口草莓桃', '抹茶奶茶'...",2,五棵松水牛喝水实况。1. 煮叶 茉莉花茶冰沙/冰摇荔枝乌龙茶 喜欢喝茶的uu请一定不要错过煮...,"[茶百道, 霸王茶姬, 宝珠奶酪, 奈雪的茶, manner, 一点点, 喜茶, 茉酸奶, ...","[桂花乌龙拿铁, 提拉米苏拿铁, 冰淇淋, 桂馥兰香, 生巧芝士, 豆乳米麻薯, 茉莉花茶冰..."
3,2024022321780757982,"石牌东「喜茶」9元开团!16款招牌任你选!数量有限,速度囤!。 \n\n长按进入购买 \n\...",['喜茶'],"['小奶茉', '绯红厚乳茶', '老丛厚乳茶', '多肉葡萄', '轻芝多肉葡萄', '酷...",4,"石牌东「喜茶」9元开团!16款招牌任你选!数量有限,速度囤!。 长按进入购买 9元抢喜茶新店...",[喜茶],"[波波真乳茶, 超多肉芒芒甘露, 芝芝绿妍茶后, 椰椰芒芒, 多肉葡, 老丛厚乳茶, 轻芒芒..."
4,2024022221776223342,"上海奶茶店推荐饮品。🌟门店少,路过请尝试 另茶:大马椰丸奶茶,雪域咸奶茶,法式玫瑰奶茶 红茶...","['茶为天', '霸王茶姬', '茶理宜世', '喜茶', '乐乐茶', 'cococean...","['大马椰丸奶茶', '雪域咸奶茶', '法式玫瑰奶茶', '路易博士牛乳茶', '正山醇'...",5,"上海奶茶店推荐饮品。门店少,路过请尝试 另茶:大马椰丸奶茶,雪域咸奶茶,法式玫瑰奶茶 红茶公...","[阿嬷奶茶, 茉莉奶白, 霸王茶姬, cococean, 吃茶三千, blueglass, ...","[正山醇, 陈香南糯, 开心果椰子糖, 法式玫瑰奶茶, 路易博士牛乳茶, 抹茶波霸脏脏茶, ..."
...,...,...,...,...,...,...,...,...
279,2024022621775573622,#霸王茶姬##霸王茶姬[超话]# 霸王茶姬代下单 千山雪金丝小种 去云南玫瑰普洱 万山红金丝...,['霸王茶姬'],"['千山雪金丝小种', '去云南玫瑰普洱', '万山红金丝小种', '桂子飘飘', '清沫观...",294,#霸王茶姬##霸王茶姬[超话]# 霸王茶姬代下单 千山雪金丝小种 去云南玫瑰普洱 万山红金丝...,[霸王茶姬],"[桂子飘飘, 千山雪金丝小种, 去云南玫瑰普洱, 万山红金丝小种, 白雾红尘, 清沫观音]"
280,2024022521781278164,Plog |过年喝咖啡☕️。一杯在海宁西田城的and coffee 西班牙拿铁 一杯在斜桥h...,"['霸王茶姬', '瑞幸']","['西班牙拿铁', '桂花拿铁咖啡', '伯牙绝弦', '樱花拿铁']",295,plog |过年喝咖啡。一杯在海宁西田城的and coffee 西班牙拿铁 一杯在斜桥hyg...,"[瑞幸, 霸王茶姬]","[樱花拿铁, 桂花拿铁咖啡, 伯牙绝弦, 西班牙拿铁]"
281,2024022421780233003,"生酮饮食计划-DAY2。上次进食时间是昨天1:30那个喜多多仙草丸子,第一天掉秤非常理想↓1...","['星巴克', '霸王茶姬']","['热美式', '云南普洱玫瑰茶']",296,"生酮饮食计划-day2。上次进食时间是昨天1:30那个喜多多仙草丸子,第一天掉秤非常理想↓1...","[星巴克, 霸王茶姬]","[喜多多, 热美式, 云南普洱玫瑰茶]"
282,2024022421770715205,福气满满人间灯圆🧨一起跟姬姐闹元宵🏮 天宫月圆🌕人间灯圆🏮岁岁年年皆平安~ 新年的第一轮圆月...,['霸王茶姬'],"['桂馥兰香', '青青糯山', '青沫观音']",297,福气满满人间灯圆一起跟姬姐闹元宵 天宫月圆人间灯圆岁岁年年皆平安~ 新年的第一轮圆月也是新的...,[霸王茶姬],"[乌龙茶, 桂花香, 桂馥兰香, 清香乌龙茶, 青青糯山, 青沫观音, 清香乌龙]"


In [26]:
df_data.to_excel(f"test.xlsx", index=False)

In [27]:
df_data.fillna('[]', inplace=True)

## eval

In [28]:
def checkIfOverlap(pred_val, true_val, text):
    rang_a = findBoundary(true_val, text)
    rang_b = findBoundary(pred_val, text)
    if len(rang_a) == 0 or len(rang_b) == 0:
        return False
    else:
        for i, j in rang_a:
            for k, m in rang_b:
                intersec = set(range(i, j)).intersection(set(range(k, m)))
                if len(intersec) > 0:
                    return True
                else:
                    return False

def findBoundary(val, text):
    res = []
    for i in range(0, len(text) - len(val) + 1):
        if text[i:i + len(val)] == val:
            res.append((i, i + len(val)))
    return res

In [29]:
def partial_match(pred_list, golden_list, text):
    precision_hit, precision_all = 0, 0
    recall_hit, recall_all = 0, 0
    
    for pred in pred_list:
        precision_all += 1
        for golden in golden_list:
            golden = golden.lower()
            if checkIfOverlap(pred, golden, text):
                precision_hit += 1
                break
    
    for golden in golden_list:
        golden = golden.lower()
        recall_all += 1
        for pred in pred_list:
            if checkIfOverlap(pred, golden, text):
                recall_hit += 1
                break

    return precision_hit, precision_all, recall_hit, recall_all

In [30]:
args.entity_type

['brand', 'product']

In [31]:
result = {}
for entity_type in tqdm(args.entity_type):
    p_hit, p_all, r_hit, r_all = 0, 0, 0, 0
    for _, row in df_data.iterrows():
        golden_list = eval(row[f'human_{entity_type}']) ##
        pred_list = row[f'pred_{entity_type}'] ##
        text = str(row['context_cleaned']) ##

        p_hit_, p_all_, r_hit_, r_all_ = partial_match(pred_list, golden_list, text)
        p_hit += p_hit_
        p_all += p_all_
        r_hit += r_hit_
        r_all += r_all_
    
    p = p_hit / p_all
    r = r_hit / r_all
    result.update(
        {
            f"{entity_type}": {
                "precision": round(p, 4),
                "recall": round(r, 4),
                "f1": round(2 * p * r / (p + r), 4),
                "details": [p_hit, p_all, r_hit, r_all]
            }
        }
    )

  0%|          | 0/2 [00:00<?, ?it/s]

In [32]:
result

{'brand': {'precision': 0.931,
  'recall': 0.8584,
  'f1': 0.8932,
  'details': [769, 826, 770, 897]},
 'product': {'precision': 0.7939,
  'recall': 0.8977,
  'f1': 0.8426,
  'details': [1803, 2271, 1825, 2033]}}

In [33]:
from ner_eval import *

In [34]:
infile = f"test.xlsx"
outfile = f"test_statistic.xlsx"
data_statistic_2("pred", infile, outfile)

no


In [35]:
outfile_1 = f"test_matrix.xlsx"
outfile_2 = f"test_evaluate_doc_level.xlsx"
outfile_3 = f"test_evaluate_entity_level.xlsx"
data_evaluate_2(outfile, outfile_1, outfile_2, outfile_3)