In [201]:
import nltk
import pandas as pd
import datetime
from nltk.tokenize import word_tokenize
from nltk import pos_tag
from nltk.chunk import ne_chunk

In [192]:
gt_label_all = pd.read_excel('label_update.xlsx')
gt_id = gt_label_all['id']
gt_person = gt_label_all['person']
gt_org = gt_label_all['organization']
gt_loc = gt_label_all['location']
gt_date = gt_label_all['date']

In [193]:
def extract_entities(tree, entity_type):
    entities = []
    if hasattr(tree, 'label') and tree.label() == entity_type:
        entities.append(' '.join([child[0] for child in tree]))
    for subtree in tree:
        if type(subtree) is nltk.Tree:
            entities.extend(extract_entities(subtree, entity_type))
    return entities

In [194]:
def calculate_precision_recall(predicted_entities, ground_truth_entities):
    true_positives = len(set(predicted_entities) & set(ground_truth_entities))
    precision = true_positives / len(set(predicted_entities)) if len(predicted_entities) > 0 else 0
    recall = true_positives / len(set(ground_truth_entities)) if len(ground_truth_entities) > 0 else 0
                                  
    return precision, recall

In [188]:
# 不管了
def convert_time_to_set(time_item):
    if pd.isna(time_item):
        return None
    elif isinstance(time_item,datetime.datetime):
        return time_item.strftime('%Y-%m-%d')
    elif isinstance(time_item,str):
        time_list = [gt.strip() for gt in str(gt_date[num]).split(',') if gt.strip()]
        return time_list
    else:
        return None

gt_date_clean = list(gt_label_all['date'].apply(convert_time_to_set))

In [189]:
print(gt_date_clean[4])

2012-09-03


In [195]:
precision_people_list = []
recall_people_list = []
precision_organizations_list = []
recall_organizations_list = []
precision_locations_list = []
recall_locations_list = []
precision_dates_list = []
recall_dates_list = []

In [136]:
len(gt_id)

250

In [196]:
for num,file_path in enumerate(gt_id):
    file_path_comp = 'label_story/'+file_path
    with open(file_path_comp, 'r', encoding='utf-8') as file:
        file_lines = file.readlines()
    
    filtered_lines = [line.strip() for line in file_lines if line.strip()]
                      #if not (line.startswith('@highlight') 
                       #       or line.startswith('Scroll down for video')) 
                      #and line.strip()]
    article_text = ''.join(filtered_lines)
    
    words = word_tokenize(article_text) #对文本进行分词
    word_pos_tags = pos_tag(words)  #对文本进行词类标记
                      
    # 使用NLTK的命名实体识别器提取实体：NLTK提供了一个命名实体识别器（NER），可以用于提取实体
    # 使用ne_chunk函数来执行命名实体识别
    entity_tree = ne_chunk(word_pos_tags)
    
    people_entities = extract_entities(entity_tree, 'PERSON')
    location_entities = extract_entities(entity_tree, 'GPE')
    date_entities = extract_entities(entity_tree, 'DATE')
    organization_entities = extract_entities(entity_tree, 'ORGANIZATION')
    
    #去掉结尾的逗号、去掉单词间可能存在的空格、最后根据逗号split
    gt_person_clean = [gt.strip() for gt in str(gt_person[num]).split(',') if gt.strip()]
    gt_org_clean = [gt.strip() for gt in str(gt_org[num]).split(',') if gt.strip()]
    gt_loc_clean = [gt.strip() for gt in str(gt_loc[num]).split(',') if gt.strip()]
    gt_date_clean = [gt.strip() for gt in str(gt_date[num]).split(',') if gt.strip()]
    
    precision_people, recall_people = calculate_precision_recall(people_entities,gt_person_clean)
    precision_organizations, recall_organizations = calculate_precision_recall(organization_entities, gt_org_clean)
    precision_locations, recall_locations = calculate_precision_recall(location_entities,gt_loc_clean)
    precision_dates, recall_dates = calculate_precision_recall(date_entities,gt_date_clean)
    
    precision_people_list.append(precision_people)
    recall_people_list.append(recall_people)
    precision_organizations_list.append(precision_organizations)
    recall_organizations_list.append(recall_organizations)
    precision_locations_list.append(precision_locations)
    recall_locations_list.append(recall_locations)
    precision_dates_list.append(precision_dates)
    recall_dates_list.append(recall_dates)
    

In [197]:
len(precision_people_list)

250

In [198]:
result_dict = {'precision_people':precision_people_list,
               'recall_people':recall_people_list,
               'precision_organizations': precision_organizations_list,
               'recall_organizations':recall_organizations_list,
               'precision_locations':precision_locations_list,
               'recall_locations':recall_locations_list,
               'precision_dates':precision_dates_list,
               'recall_dates':recall_dates_list}

result_df = pd.DataFrame(result_dict,index=gt_id)

In [199]:
result_df

Unnamed: 0_level_0,precision_people,recall_people,precision_organizations,recall_organizations,precision_locations,recall_locations,precision_dates,recall_dates
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
0000800d9058217f6509d7e63ad475e2de0da611.story,0.416667,0.384615,0.071429,0.500000,0.200000,0.125000,0,0.0
0001d4ce3598e37f20a47fe609736f72e5d73467.story,0.071429,0.100000,0.153846,0.400000,0.142857,0.166667,0,0.0
0002067d13d3ca304e0bc98d04dde85d4091c55e.story,0.375000,0.428571,0.125000,0.500000,0.166667,0.333333,0,0.0
000219931d2c3aae55dc2acdc5f690d0c112ab17.story,0.461538,0.666667,0.125000,0.333333,0.666667,0.666667,0,0.0
00022dbfa44ccdb94c1dc06938047e258076cf75.story,0.333333,0.500000,0.071429,0.500000,0.166667,1.000000,0,0.0
...,...,...,...,...,...,...,...,...
76e545faae492f3a5e732e312d508d9c0af0cbae.story,0.400000,0.600000,0.000000,0.000000,0.000000,0.000000,0,0.0
77fc262c408f5284bd6dcbca25376f7e9e4b20a8.story,0.555556,0.833333,0.333333,0.666667,0.750000,0.500000,0,0.0
78e02695b0ba309168bf21d09d67532eaf1c00e1.story,0.428571,0.600000,0.166667,0.400000,0.571429,0.800000,0,0.0
79a4fbcabce694f7f14332c4375d5b58a53380aa.story,0.333333,1.000000,0.125000,0.250000,0.500000,0.428571,0,0.0


In [202]:
result_df.to_excel('nltk_output.xlsx', index=True)