# Entity linking

You are provided with data from a knowledge graph and asked to annotate an input document using a general entity linking pipeline consisting of three steps: mention detection, candidate selection (entity ranking), and disambiguation.

In [12]:
import ipytest
import pytest

# Let ipytest configure pytest for this notebook, so that the test cells
# below (run via the `%%run_pytest[clean]` cell magic) can be executed.
ipytest.autoconfig()

## 1) Mention detection

You are given an excerpt from a surface form dictionary in the format `SF_DICT[mention][entity] = count`, where `count` refers to the number of times `mention` was linked to `entity` in a given training corpus.
The total count of linked occurrences of a mention is stored under the key `_total` (i.e., the number of times the mention was linked to any entity in the training corpus).
Note that not all linked entities are listed in the dictionary, hence the counts do not necessarily sum up to `_total`.

In [13]:
# Excerpt of a surface form dictionary: mention -> {entity: link count}.
# The special key "_total" holds how often the mention was linked to any
# entity; not every linked entity is listed, so the counts may sum to less.
SF_DICT = {
    "1992 elections": {
        "wikipedia:Philippine_general_election,_1992": 9,
        "wikipedia:Angolan_presidential_election,_1992": 1,
        "_total": 98,
    },
    "angola": {
        "wikipedia:Angola": 4026,
        "wikipedia:Angola_(Portugal)": 6,
        "wikipedia:Angola_national_football_team": 120,
        "_total": 4298,
    },
    "democracy": {
        "wikipedia:Democracy": 108,
        "wikipedia:Democracy_(album)": 3,
        "_total": 2162,
    },
    "multiparty democracy": {
        "wikipedia:multiparty_democracy": 11,
        "_total": 11,
    },
    "one party": {
        "wikipedia:Non-possessors": 1,
        "wikipedia:Single-party_state": 5,
        "_total": 983,
    },
}

In [14]:
# Input document, lowercased so it can be matched against the lowercase
# surface forms used as keys in SF_DICT.
TEXT = "".join([
    "Angola changed from a one-party Marxist-Leninist system ",
    "ruled by the MPLA to a formal multiparty democracy ",
    "following the 1992 elections",
]).lower()

We perform mention detection based on the following heuristic:

  - At each term position
    - Start with the longest possible n-gram (n = 8, or the number of remaining terms if that is smaller).
    - If the n-gram is found in the dictionary, the mention and the corresponding entities are kept (and the shorter n-grams are ignored). Otherwise, we try to match the (n-1)-grams. 
    - Repeat until a mention is found or n reaches 0.

In [15]:
def detect_mentions(text, sf_dict, max_len=8):
    """Performs mention detection in text against a given surface form dictionary.

    At every term position, the longest n-gram that is a key of `sf_dict` is
    taken as the mention starting there. The n-gram has at most `max_len`
    terms and never runs past the end of the text. Shorter n-grams starting at
    the same position are then ignored. Mentions starting at different
    positions may overlap; they are all kept and left for disambiguation.
    Matching is exact (case-sensitive) against the dictionary keys.

    Args:
        text: Input text (tokenized on whitespace).
        sf_dict: Surface form dictionary.
        max_len: Maximum n-gram length to consider (default: 8).

    Returns:
        List of matches as `(pos, mention, entity)` tuples ordered by pos, mention, and entity.
        (Term positions are indexed from 0.)
    """
    matches = []
    tokens = text.split()
    for pos in range(len(tokens)):
        # The longest candidate is capped both by max_len and by the number of
        # remaining tokens. (Taking the max of the two would try n-grams far
        # longer than max_len, e.g. the whole rest of the text.)
        longest = min(max_len, len(tokens) - pos)
        for n in range(longest, 0, -1):
            n_gram = " ".join(tokens[pos:pos + n])
            if n_gram in sf_dict:
                # "_total" is a bookkeeping key, not an entity.
                matches.extend(
                    (pos, n_gram, entity)
                    for entity in sorted(sf_dict[n_gram])
                    if entity != "_total"
                )
                break  # Longest match found; skip shorter n-grams here.
    return matches

Tests.

In [16]:
%%run_pytest[clean]

def test_detect_mentions():
    """Every term position yields its longest dictionary match, with all entities."""
    expected = [
        (0, "angola", "wikipedia:Angola"),
        (0, "angola", "wikipedia:Angola_(Portugal)"),
        (0, "angola", "wikipedia:Angola_national_football_team"),
        (14, "multiparty democracy", "wikipedia:multiparty_democracy"),
        (15, "democracy", "wikipedia:Democracy"),
        (15, "democracy", "wikipedia:Democracy_(album)"),
        (18, "1992 elections", "wikipedia:Angolan_presidential_election,_1992"),
        (18, "1992 elections", "wikipedia:Philippine_general_election,_1992"),
    ]
    detected = detect_mentions(TEXT, SF_DICT)
    assert detected == expected

.                                                                                  [100%]
1 passed in 0.01s


## 2) Entity ranking

Entity ranking is based on the commonness score:

$$Commonness(e, m) = p(e|m) = \frac{n(m, e)}{\sum_{e'} n(m, e')}$$

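where $n(m, e)$ denotes the number of times entity $e$ is the link target of mention $m$. In practice, the denominator is taken from the `_total` entry of the surface form dictionary, since not every linked entity is listed there.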

In [17]:
def commonness(entity, mention, sf_dict):
    """Computes the commonness for an entity-mention pair given a surface form dictionary.

    Commonness is estimated as n(mention, entity) / sf_dict[mention]["_total"],
    i.e., the entity's link count divided by the number of times the mention
    was linked to any entity.

    Args:
        entity: Entity.
        mention: Mention.
        sf_dict: Surface form dictionary (containing entity-mention count statistics).

    Returns:
        Commonness (float). An entity not listed for the mention scores 0, and
        a mention whose `_total` is 0 scores 0.0. Returns None if the mention
        is not in the dictionary at all.
    """
    if mention not in sf_dict:
        return None
    counts = sf_dict[mention]
    total = counts["_total"]
    if not total:
        # A mention that was never linked: avoid ZeroDivisionError.
        return 0.0
    return counts.get(entity, 0) / total

In [18]:
def rank_entities(mentions, sf_dict, k=5):
    """Ranks candidate entities for each mention based on commonness and retains
    the top-k highest-scoring entities for each mention.

    Args:
        mentions: Detected mentions (list of `(pos, mention, entity)` tuples).
        sf_dict: Surface form dictionary.
        k: Number of top-scoring entities to keep for each mention.

    Returns:
        Candidate entities with scores for each mention. Each mention is
        represented as a dict `{'mention': xxx, 'pos': yyy, 'entities': zzz}`,
        where entities is a list of `(entity, score)` tuples ordered by score desc.
        Mentions appear in the order they first occur in `mentions`. Entities
        with equal scores keep their input order (the sort is stable).
    """
    # Group candidate entities by (pos, mention). A tuple key is used rather
    # than a "pos::mention" string, so mentions containing "::" still work.
    entities_per_mention = {}
    for pos, mention, entity in mentions:
        entities_per_mention.setdefault((pos, mention), []).append(entity)

    # Score all candidate entities for each mention.
    mentions_entities = []
    for (pos, mention), entities in entities_per_mention.items():
        entity_scores = sorted(
            ((entity, commonness(entity, mention, sf_dict)) for entity in entities),
            key=lambda pair: pair[1],
            reverse=True,
        )
        mentions_entities.append({
            "mention": mention,
            "pos": int(pos),
            "entities": entity_scores[:k],
        })
    return mentions_entities

Tests.

In [19]:
%%run_pytest[clean]

@pytest.mark.parametrize(
    "mention,entity,correct_value",
    [
        ("1992 elections", "wikipedia:Philippine_general_election,_1992", 9/98),
        ("1992 elections", "wikipedia:Angolan_presidential_election,_1992", 1/98),
        ("angola", "wikipedia:Angola", 4026 / 4298),
        ("angola", "wikipedia:Angola_national_football_team", 120 / 4298),
        ("democracy", "wikipedia:Democracy", 108/2162),
        ("democracy", "wikipedia:Democracy_(album)", 3/2162),
        ("multiparty democracy", "wikipedia:multiparty_democracy", 1),
    ],
)
def test_commonness(entity, mention, correct_value):
    """Commonness is the entity's link count over the mention's `_total`."""
    expected = pytest.approx(correct_value, rel=1e-3)
    assert commonness(entity, mention, SF_DICT) == expected
    
def test_rank_entities():
    """With k=2, each mention keeps its two highest-commonness entities."""
    ranked_entities = rank_entities(detect_mentions(TEXT, SF_DICT), SF_DICT, k=2)
    expected_angola = {
        "mention": "angola",
        "pos": 0,
        "entities": [
            ("wikipedia:Angola", 0.9367147510469986),
            ("wikipedia:Angola_national_football_team", 0.02791996277338297),
        ],
    }
    expected_multiparty = {
        "mention": "multiparty democracy",
        "pos": 14,
        "entities": [("wikipedia:multiparty_democracy", 1.0)],
    }
    assert ranked_entities[0] == expected_angola
    assert ranked_entities[1] == expected_multiparty

........                                                                           [100%]
8 passed in 0.06s


## 3) Disambiguation

Perform disambiguation by simply returning, for each mention, the entity with the highest score, provided that this score reaches the given threshold.

In case of contained or overlapping mentions, keep only the one with the higher score.

In [20]:
def disambiguate(ranked_entities, threshold=0.1):
    """Disambiguates entities for each mention by keeping only the highest-scoring one.

    For each mention (in input order), its top-ranked candidate is considered
    if its score is at least `threshold`. The candidate may overlap already
    accepted annotations (containment or partial overlap). If none of those
    annotations has a strictly higher score, they are removed entirely and
    replaced by the new one; otherwise the new candidate is dropped.

    Args:
        ranked_entities: List of mentions along with a ranked list of candidate
            entities (dicts with 'mention', 'pos' and 'entities' keys, where
            'entities' is a list of `(entity, score)` tuples, best first).
        threshold: Score threshold; top candidates scoring below it are discarded.

    Returns:
        Entity annotations as a list of `(pos, mention, entity)` tuples, sorted by pos.
    """
    # Maps every covered term position to the annotation occupying it. We can
    # greedily replace annotations whenever a higher-scoring one is found.
    annotations = {}
    for candidates in ranked_entities:
        if not candidates["entities"]:
            # No candidate to link (e.g. ranked with k=0).
            continue
        entity, score = candidates["entities"][0]
        if score < threshold:
            continue
        start_pos = candidates["pos"]
        mention = candidates["mention"]
        mention_length = len(mention.split())
        span = range(start_pos, start_pos + mention_length)

        # Reject if any overlapping annotation scores strictly higher.
        if any(pos in annotations and annotations[pos]["score"] > score
               for pos in span):
            continue

        # Remove every overlapping annotation completely (all of its
        # positions, including those outside the new span).
        for pos in span:
            old = annotations.get(pos)
            if old is not None:
                old_start = old["start_pos"]
                for i in range(old_start, old_start + old["mention_length"]):
                    annotations.pop(i, None)

        # Store the new annotation at each of its term positions.
        new_annotation = {
            "score": score,
            "entity": entity,
            "mention": mention,
            "start_pos": start_pos,
            "mention_length": mention_length,
        }
        for pos in span:
            annotations[pos] = new_annotation

    # Convert to the output format: one tuple per annotation, at its start position.
    return [
        (pos, annotation["mention"], annotation["entity"])
        for pos, annotation in sorted(annotations.items(), key=lambda item: item[0])
        if pos == annotation["start_pos"]
    ]

Tests.

In [21]:
%%run_pytest[clean]

def test_disambiguate():
    """End-to-end: detection, ranking, then disambiguation with a low threshold."""
    ranked_entities = rank_entities(detect_mentions(TEXT, SF_DICT), SF_DICT)
    expected = [
        (0, "angola", "wikipedia:Angola"),
        (14, "multiparty democracy", "wikipedia:multiparty_democracy"),
        (18, "1992 elections", "wikipedia:Philippine_general_election,_1992"),
    ]
    assert disambiguate(ranked_entities, threshold=0.01) == expected

.                                                                                  [100%]
1 passed in 0.01s
