In [1]:
import re
import json
import random
from collections import defaultdict, OrderedDict

In [2]:
data_dir = "/Users/lvzhiheng/Desktop/data/travel_KG"
baike_dir = os.path.join(data_dir, "带url的正文和摘要triple/第二轮")

In [3]:
coarse_types = ["菜品", "建筑", "景点", "老字号门店", "人物", "文物", "组织机构"]

# convert type names to full path names, e.g. "名人故居" -> "/景点/人文历史景点/名人故居"
parent_class = {}
with open(os.path.join(data_dir, "subclassof.txt"), "r") as file:
    for line in file:
        category, parent_category = line.strip().split()
        parent_class[category] = parent_category

# query up parent_class until root
full_type_name = {}
for category in parent_class:
    full_name = category
    parent_category = parent_class[category]
    while parent_category != "":
        full_name = f"{parent_category}/{full_name}"
        parent_category = parent_class[parent_category] if parent_category in parent_class else ""
    full_type_name[category] = "/" + full_name

full_type_name.update({type_name: f"/{type_name}" for type_name in coarse_types})

# for type_name, full_name in full_type_name.items():
#     print(type_name, "->", full_name)

# read entity types from files
entity_types = defaultdict(list)
with open(os.path.join(data_dir, "instanceOf_jd.txt"), "r") as file:
    for line in file:
        entity, type_name = line.strip().split()
        type_name = full_type_name[type_name]
        if type_name not in entity_types[entity]:
            entity_types[entity].append(type_name)
print(f"#entity with fine-grained types: {len(entity_types)}")

for type_name in coarse_types:
    file_path = os.path.join(data_dir, f"entities/{type_name}.txt")
    with open(file_path, "r") as file:
        for line in file:
            entity = line.strip()
            full_name = full_type_name[type_name]
            if full_name in entity_types[entity]:
                continue
            entity_types[entity].append(full_name)
print(f"#total entities: {len(entity_types)}")

# add all up-level types to entity_types
for entity, types in entity_types.items():
    expand_types = []
    for type_name in types:
        levels = type_name[1:].split("/")
        for idx in range(1, len(levels) + 1):
            parent_type = "/" + "/".join(levels[:idx])
            if parent_type not in expand_types:
                expand_types.append(parent_type)
    entity_types[entity] = expand_types

#entity with fine-grained types: 1830
#total entities: 75049


In [4]:
fet_test_file = os.path.join(data_dir, "test.json")

with open(fet_test_file, "r") as file:
    test_data = json.load(file)

test_entities = set()
test_examples = []
for example in test_data:
    sentence = example["sent"]
    # replace label with its full name
    labels = ["老字号门店" if label == "门店" else label for label in example["labels"]]
    labels = [full_type_name[label] for label in labels]
    example["labels"] = labels

    entity = sentence[example["start"]:example["end"]]
    test_entities.add(entity)
    if entity in entity_types and all(label in entity_types[entity] for label in labels):
        example["labels"] = entity_types[entity]
    test_examples.append(example)

In [5]:
def extract_sentences(doc):
    content = re.sub(r"([。？！])", r"\1\n", doc)
    sentences = []
    for sentence in content.split():
        if len(sentence) < 4:
            continue
        sentences.append(sentence)
    return sentences


examples = []
# DS to build FET dataset
for file_name in os.listdir(baike_dir):
    file_path = os.path.join(baike_dir, file_name)

    with open(file_path, "r") as file:
        for line in file:
            # name  (full name in baike)  url  abstract/article
            # (full name) / abstract / article may be absent
            parts = line.split("\t\t")
            article = parts[-1].strip()
            if len(article) == 0:
                continue

            paragraphs = article.split("::;")
            if paragraphs[0] not in ["AbstractHere", "ArticleHere"]:
                continue
            for paragraph in paragraphs[1:]:
                if paragraph.startswith("==") and paragraph.endswith("=="):
                    continue

                for sentence in extract_sentences(paragraph):
                    # find anchor links
                    anchor_spans = [match.span() for match in re.finditer(r"\[\[(.*?)\|(.*?)\]\]", sentence)]
                    if len(anchor_spans) == 0:
                        continue
                    text, entity_spans = [], []
                    end_point, length = 0, 0
                    for start, end in anchor_spans:
                        text.append(sentence[end_point:start])
                        length += start - end_point
                        # replace anchor link with raw text
                        entity, url = sentence[start:end][2:-2].split("|")
                        text.append(entity)
                        entity_spans.append([length, length + len(entity)])
                        length += len(entity)
                        end_point = end
                    text.append(sentence[end_point:])
                    text = "".join(text)

                    for start, end in entity_spans:
                        entity = text[start:end]
                        # TODO: more accurate ways to get entity types
                        if entity not in entity_types or entity in test_entities:
                            continue
                        examples.append(OrderedDict([
                            ("sent", text),
                            ("labels", entity_types[entity]),
                            ("start", start),
                            ("end", end)
                        ]))
        

In [6]:
train_file = os.path.join(data_dir, "FET/train.json")
dev_file = os.path.join(data_dir, "FET/dev.json")
test_file = os.path.join(data_dir, "FET/test.json")


def remove_duplication(examples):
    example_ids = set()
    non_dup_examples = []
    for example in examples:
        # same (sent, start, end) as the same example
        example_id = (example["sent"], example["start"], example["end"])
        if example_id in example_ids:
            continue
        non_dup_examples.append(example)
        example_ids.add(example_id)
    return non_dup_examples


examples = remove_duplication(examples)

# random split to train & dev
entity_examples = defaultdict(list)
for idx, example in enumerate(examples):
    entity = example["sent"][example["start"] : example["end"]]
    entity_examples[entity].append(idx)

all_entity = list(entity_examples.keys())
random.seed(1234)
random.shuffle(all_entity)

# ensure train dev no entity overlap
# TODO: better ways to compare entites (e.g. entity id) rather than text
train_examples, dev_examples = [], []
n_entity = 0
while n_entity < len(all_entity) and len(train_examples) < len(examples) * 0.9:
    entity = all_entity[n_entity]
    train_examples += [examples[idx] for idx in entity_examples[entity]]
    n_entity += 1

while n_entity < len(all_entity):
    entity = all_entity[n_entity]
    dev_examples += [examples[idx] for idx in entity_examples[entity]]
    n_entity += 1
print(f"#train_examples: {len(train_examples)}, ratio: {len(train_examples) / len(examples)}")
print(f"#dev_examples: {len(dev_examples)}, ratio: {len(dev_examples) / len(examples)}")


test_examples = remove_duplication(test_examples)
# train dataset has no "/景点/自然风光/池塘" label
test_examples = [example for example in test_examples if "/景点/自然风光/池塘" not in example["labels"]]

for target_file, split_examples in zip([train_file, dev_file, test_file],
                                       [train_examples, dev_examples, test_examples]):
    with open(target_file, "w") as writer:
        json.dump(split_examples, writer, ensure_ascii=False, indent=4)

#train_examples: 168484, ratio: 0.9000502152846779
#dev_examples: 18710, ratio: 0.09994978471532208


In [7]:
# collect entity url from baike anchor text
entity_url = {}
for file_name in os.listdir(baike_dir):
    file_path = os.path.join(baike_dir, file_name)

    with open(file_path, "r") as file:
        for line in file:
            for match in re.finditer(r"\[\[(.*?)\|(http.*?)\]\]", line):
                entity, url = match.group(1), match.group(2)
                entity_url[entity] = url

## Dataset statistics

In [8]:
# train data sentence length distribution
import numpy as np

def dataset_stat(examples):
    print("sentence length percentage")
    sent_lens = np.array([len(example["sent"]) for example in examples])
    for i in range(90, 101):
        print(i, np.percentile(sent_lens, i))

    label_set = set()
    n_fine_grained = 0
    for example in examples:
        label_set.update(example["labels"])
        if len(example["labels"]) > 1:
            n_fine_grained += 1

    print(f"#examples in dataset: {len(examples)}")
    print(f"#labels in dataset: {len(label_set)}")
    print(f"#fine grained examples (> 1 labels) ratio: {n_fine_grained / len(examples)}")


dataset_stat(train_examples)
dataset_stat(dev_examples)
dataset_stat(test_examples)
n_fine_grained_entity = sum(1 if len(types) > 1 else 0 for entity, types in entity_types.items())
print(f"ratio of entities having > 1 types: {n_fine_grained_entity / len(entity_types)}")

sentence length percentage
90 114.0
91 119.0
92 125.0
93 133.0
94 143.0
95 154.0
96 171.0
97 195.0
98 227.0
99 293.0
100 1838.0
#examples in dataset: 168484
#labels in dataset: 33
#fine grained examples (> 1 labels) ratio: 0.017894874290733838
sentence length percentage
90 108.0
91 113.0
92 119.0
93 125.0
94 134.0
95 145.0
96 159.0
97 182.0
98 214.0
99 263.0
100 963.0
#examples in dataset: 18710
#labels in dataset: 26
#fine grained examples (> 1 labels) ratio: 0.012292891501870658
sentence length percentage
90 101.0
91 105.0
92 107.0
93 111.0
94 114.0
95 118.0
96 124.96000000000004
97 128.0
98 133.0
99 140.0
100 150.0
#examples in dataset: 4677
#labels in dataset: 33
#fine grained examples (> 1 labels) ratio: 0.26726534103057514
ratio of entities having > 1 types: 0.03016695758770936


In [9]:
# TODO: DS entity disambigution
# train dev & test label

In [10]:
# check baike text coverage of fine-grained entities
fine_ents = set(entity for entity, types in entity_types.items() if len(types) > 1)
baike_ents = set()

for file_name in os.listdir(baike_dir):
    file_path = os.path.join(baike_dir, file_name)

    with open(file_path, "r") as file:
        for line in file:
            name = line.split("\t\t")[0]
            baike_ents.add(name)


In [11]:
len(fine_ents), len(baike_ents & fine_ents)

(2264, 1229)