In [1]:
from typing import Union

import nltk
nltk.download('wordnet')
from nltk.corpus import wordnet

import numpy as np

[nltk_data] Downloading package wordnet to
[nltk_data]     /home/davidsule/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [2]:
# Space-separated entity-type labels; each label is looked up in WordNet below.
ENTITY_LABELS = (
    "journal album algorithm astronomer award band book chemical "
    "conference country discipline election enzyme event field genre "
    "location magazine metrics miscellaneous artist instrument "
    "Organisation person poem politics politician product java protein "
    "researcher scientist song task theory university writer"
)
# Individual labels, in the order they appear above (indices are used by meaning_dict).
ENTITY_LABELS_SPLIT = ENTITY_LABELS.split()

In [3]:
def most_frequent_synset(entities: list) -> list:
    """Map every word in ``entities`` to its first WordNet synset.

    The first entry returned by ``wordnet.synsets(word)`` is taken as the
    most frequent sense.  A word without any synsets raises IndexError.
    """
    return [wordnet.synsets(word)[0] for word in entities]

def minpath(synset) -> list:
    """Return the shortest of ``synset``'s hypernym paths.

    ``synset`` must expose ``hypernym_paths()`` returning a list of paths
    (each path itself a list).  When several paths share the minimum
    length, the first such path is returned.

    Raises:
        ValueError: if ``hypernym_paths()`` returns an empty list.
    """
    # min() returns the first element with the smallest key, so ties keep
    # the earliest path in hypernym_paths() order.
    return min(synset.hypernym_paths(), key=len)

def category_dict(synset_list: list, level: int = 2, words: Union[list, None] = None) -> dict:
    """Group synsets (or their words) by the hypernym at a given depth.

    For each synset in ``synset_list`` the shortest hypernym path to the
    root is taken (via ``minpath``).  The element at index ``level`` of
    that path becomes the synset's category; level 0 is the root.  If the
    path has no element at ``level``, its last element is used instead.

    ``words`` is an optional list parallel to ``synset_list``.  When it is
    given, the dictionary values hold the corresponding words instead of
    the synsets themselves.

    Returns a dict mapping each category (a hypernym) to the list of
    synsets, or of words when ``words`` is not None, that fall under it.
    The lists are in input order.

    Raises ValueError if ``words`` is given but its length differs from
    that of ``synset_list``.
    """
    if words is not None and len(words) != len(synset_list):
        raise ValueError(
            f"words has {len(words)} items but synset_list has {len(synset_list)}"
        )
    categories = {}
    for i, syn in enumerate(synset_list):
        path = minpath(syn)
        try:
            cat = path[level]
        except IndexError:
            # Path is shorter than the requested level: fall back to its last node.
            cat = path[-1]
        label = syn if words is None else words[i]
        categories.setdefault(cat, []).append(label)
    return categories

# Get categories with most frequent word meanings
# category_dict(most_frequent_synset(ENTITY_LABELS_SPLIT), 2, ENTITY_LABELS_SPLIT)

In [4]:
# Print every WordNet sense of one label so the intended meaning can be
# chosen by hand for the manual corrections in meaning_dict.
idx = 36
word = ENTITY_LABELS_SPLIT[idx]
synlist = wordnet.synsets(word)
print(idx, word)
for i, syn in enumerate(synlist):
    print(i, syn, syn.definition())

36 writer
0 Synset('writer.n.01') writes (books or stories or articles or the like) professionally (for pay)
1 Synset('writer.n.02') a person who is able to write and has written something


In [5]:
# Manually chosen WordNet senses.  Key: index of the label in
# ENTITY_LABELS_SPLIT; value: index of the intended meaning in
# wordnet.synsets(label).  One entry per line, in label order.
meaning_dict = dict(enumerate([
    1,  # journal
    0,  # album
    0,  # algorithm
    0,  # astronomer
    1,  # award
    0,  # band
    0,  # book
    0,  # chemical
    0,  # conference
    0,  # country
    0,  # discipline
    0,  # election
    0,  # enzyme
    0,  # event
    3,  # field
    2,  # genre
    0,  # location
    0,  # magazine
    3,  # metrics
    0,  # miscellaneous
    0,  # artist
    5,  # instrument
    0,  # Organisation
    0,  # person
    0,  # poem
    0,  # politics
    0,  # politician
    0,  # product
    2,  # java
    0,  # protein
    0,  # researcher
    0,  # scientist
    0,  # song
    1,  # task
    0,  # theory
    2,  # university
    0,  # writer
]))

In [6]:
# Resolve each label to its manually chosen sense, then group the labels by
# the hypernym two steps below the root on their shortest hypernym path.
ENTITY_SYNS = [
    wordnet.synsets(word)[meaning_dict[i]]
    for i, word in enumerate(ENTITY_LABELS_SPLIT)
]
category_dict(ENTITY_SYNS, level=2, words=ENTITY_LABELS_SPLIT)

{Synset('object.n.01'): ['journal',
  'album',
  'book',
  'location',
  'magazine',
  'instrument',
  'product'],
 Synset('psychological_feature.n.01'): ['algorithm',
  'discipline',
  'election',
  'event',
  'field',
  'task',
  'theory'],
 Synset('causal_agent.n.01'): ['astronomer',
  'artist',
  'person',
  'politician',
  'researcher',
  'scientist',
  'writer'],
 Synset('communication.n.02'): ['award', 'genre', 'poem', 'java', 'song'],
 Synset('group.n.01'): ['band',
  'conference',
  'country',
  'Organisation',
  'university'],
 Synset('matter.n.03'): ['chemical', 'enzyme'],
 Synset('measure.n.02'): ['metrics'],
 Synset('assorted.s.01'): ['miscellaneous'],
 Synset('relation.n.01'): ['politics'],
 Synset('thing.n.12'): ['protein']}

In [7]:
# Calculate Wu-Palmer Similarities
def get_wup_sim(entity_syns):
    """Return the pairwise Wu-Palmer similarity matrix of ``entity_syns``.

    Only the lower triangle, diagonal included, is evaluated as
    ``entity_syns[i].wup_similarity(entity_syns[j])`` for ``j <= i``.
    Each value is mirrored to ``[j, i]``, so the returned matrix is
    symmetric.

    Returns a NumPy float array of shape
    ``(len(entity_syns), len(entity_syns))``.  A pair for which
    ``wup_similarity`` returns ``None`` is stored as ``np.nan``.

    NOTE(review): NLTK's ``wup_similarity`` may not be symmetric for
    synsets of different parts of speech.  Only one direction is
    evaluated here; confirm this is acceptable.
    """
    syns = list(entity_syns)
    n_syns = len(syns)
    similarities = np.zeros((n_syns, n_syns))
    for i, syn1 in enumerate(syns):
        # Walk the lower triangle only; the upper one is filled by mirroring.
        for j in range(i + 1):
            sim = syn1.wup_similarity(syns[j])
            # Store an explicit NaN rather than relying on NumPy's implicit
            # None -> NaN cast when assigning into a float array.
            value = np.nan if sim is None else sim
            similarities[i, j] = value
            similarities[j, i] = value
    return similarities

# Pairwise Wu-Palmer similarity matrix over the chosen entity senses.
similarities = get_wup_sim(entity_syns=ENTITY_SYNS)