In [122]:
import numpy as np
import spacy
from collections import Counter, defaultdict
from itertools import product, combinations
from sklearn.preprocessing import normalize
import settings
import json

from nltk.corpus import wordnet as wn
from tags import NounTags

## Load Model

In [6]:
# Read the JSON-encoded COCO noun tags; the context manager guarantees the
# file handle is closed once parsing finishes (a bare open() left it dangling).
with open('data/coco_noun.tags') as tags_file:
    data = json.load(tags_file)

# spaCy pipeline named in the project settings, shared by every later cell.
nlp = spacy.load(settings.SPACY_MODEL)

# Project-local noun extractor built on that pipeline.
# NOTE(review): the meaning of alpha=0.0 is defined in tags.NounTags — confirm there.
nounExtractor = NounTags(nlp, alpha=0.0)

### Check Hypernymy

In [123]:
def check_hyper(word1, word2):
    """
    Tell whether some sense of one word is a hypernym of a sense of the other.

    Every pairing of the WordNet synsets of ``word1`` and ``word2`` is
    examined; a pairing matches when one of the two synsets appears among
    the common hypernyms of the pair.

    NOTE: a later cell re-defines ``check_hyper`` with a list-based
    signature, which rebinds this name once that cell has run.

    Args:
        word1: first word, passed to ``wn.synsets``
        word2: second word, passed to ``wn.synsets``

    Return:
        bool: True as soon as one matching pairing is found, False otherwise
        (including when either word has no synsets at all).
    """
    pairings = product(wn.synsets(word1), wn.synsets(word2))
    return any(
        (second in first.common_hypernyms(second))
        or (first in second.common_hypernyms(first))
        for first, second in pairings
    )

In [184]:
def check_hyper(words):
    """
    Check that every pair of words is linked by a hypernym relation.

    Two words are considered linked when at least one WordNet synset of one
    word appears among the common hypernyms it shares with a synset of the
    other word (i.e. one sense is an ancestor of, or equal to, the other).

    Args:
        words (list): a list of words to be checked

    Return:
        bool: True when all pairs drawn from ``words`` are linked, False as
        soon as one pair is not. With fewer than two words there is no pair
        to test, so the result is True.
    """
    # Materialise once so one-shot iterables can also be paired up safely.
    words = list(words)

    # Each word takes part in len(words) - 1 pairs; remember its synsets so
    # WordNet is queried at most once per distinct word.
    synset_cache = {}

    def synsets_of(word):
        if word not in synset_cache:
            synset_cache[word] = wn.synsets(word)
        return synset_cache[word]

    def linked(word1, word2):
        # A single matching synset pairing is enough for the two words.
        return any(
            (sys2 in sys1.common_hypernyms(sys2))
            or (sys1 in sys2.common_hypernyms(sys1))
            for sys1, sys2 in product(synsets_of(word1), synsets_of(word2))
        )

    return all(linked(word1, word2) for word1, word2 in combinations(words, 2))

## Loading captions

In [194]:
# Select one image from the train2014 split and fetch its raw caption strings.
id_ = '263823'
caption_list = data['train2014'][id_]['captions']

# Run every caption through the spaCy pipeline, keeping caption order.
captions = list(map(nlp, caption_list))

# Display the raw captions.
caption_list

['A baseball player prepares to swing at the ball. ',
 'A ball player prepares to swing as the umpire and catcher look on.',
 'Two baseball players and an umpire during a game.',
 'A baseball player getting ready to swing at the ball. ',
 'A baseball game is being played with the batter up.']

#### Getting Nouns

In [195]:
# Run the noun extractor over the raw captions; it yields three parallel
# sequences: the tags, their counts, and `rlocs`
# (presumably relative locations — confirm in NounTags.preprocess).
noun_summary = nounExtractor.preprocess(caption_list)
tags, counts, rlocs = noun_summary
tags, counts, rlocs

(['ball', 'baseball', 'player', 'game', 'umpire', 'catcher', 'batter'],
 [3, 4, 4, 2, 2, 1, 1],
 [0.07142857142857142,
  0.09090909090909091,
  0.14285714285714285,
  0.18181818181818182,
  0.5,
  0.7142857142857143,
  0.7272727272727273])

#### token dict

In [196]:
# Collect every noun and verb token of the parsed captions, grouped under
# its (lemma, POS tag) pair.
tokens_by_key = defaultdict(set)
for doc in captions:
    for token in doc:
        if token.pos_ == 'NOUN' or token.pos_ == 'VERB':
            tokens_by_key[token.lemma_, token.pos_].add(token)

# Keep a plain dict without the groups that hold exactly one token, so only
# repeated (lemma, POS) pairs remain.
token_dict = {
    key: group for key, group in tokens_by_key.items() if len(group) != 1
}

In [197]:
# Display the repeated noun/verb tokens grouped by (lemma, POS tag).
token_dict

{('ball', 'NOUN'): {ball, ball, ball},
 ('baseball', 'NOUN'): {baseball, baseball, baseball, baseball},
 ('be', 'VERB'): {is, being},
 ('game', 'NOUN'): {game, game},
 ('player', 'NOUN'): {player, player, player, players},
 ('prepare', 'VERB'): {prepares, prepares},
 ('swing', 'VERB'): {swing, swing, swing},
 ('umpire', 'NOUN'): {umpire, umpire}}

#### children token dict

In [198]:
# For each repeated (lemma, POS) group, gather the noun/verb syntactic
# children of its tokens, keyed by (group key, dependency label of the child).
children_by_role = defaultdict(set)
for key, parents in token_dict.items():
    for parent in parents:
        for child in parent.children:
            if child.pos_ == 'NOUN' or child.pos_ == 'VERB':
                children_by_role[key, child.dep_].add(child)

# Keep a plain dict without the groups that hold exactly one child, so only
# repeated children remain.
child_dict = {
    role: group for role, group in children_by_role.items() if len(group) != 1
}

In [199]:
# Display the repeated children grouped by ((lemma, POS), dependency label).
child_dict

{(('player', 'NOUN'), 'compound'): {ball, baseball, baseball, baseball},
 (('prepare', 'VERB'), 'nsubj'): {player, player},
 (('prepare', 'VERB'), 'xcomp'): {swing, swing}}

#### Merged tags

In [200]:
# Merge tags that fill the same syntactic slot and are hypernym-related,
# keeping the first-listed tag of each group (the original note calls it the
# "highest scored" one) and folding the counts of the others into it.
#
# NOTE: `counts` is updated in place, so re-running this cell without first
# re-running the noun-extraction cell adds the merged counts again.
remove_idxs = set()
merged_into = {}  # index of an absorbed tag -> index of the tag that absorbed it
for tokens in child_dict.values():
    idxs = sorted({tags.index(t.lemma_) for t in tokens if t.lemma_ in tags})
    tags_to_check = [tags[i] for i in idxs]
    if len(idxs) < 2:
        continue
    check = check_hyper(tags_to_check)
    if not check:
        continue

    # Resolve each index to the tag it has already been merged into, so that
    # overlapping groups never add counts to a removed tag (where they would
    # be lost) or add the same counts twice.
    roots = set()
    for idx in idxs:
        while idx in merged_into:
            idx = merged_into[idx]
        roots.add(idx)

    keep_idx, *absorbed = sorted(roots)
    # `idx` (not `id_`) is used as loop variable so the image id chosen in
    # the caption-loading cell is not overwritten.
    for idx in absorbed:
        counts[keep_idx] += counts[idx]
        merged_into[idx] = keep_idx
    remove_idxs.update(absorbed)

In [201]:
def _without_removed(values):
    """Return the items of ``values`` whose position is not in ``remove_idxs``."""
    return [item for pos, item in enumerate(values) if pos not in remove_idxs]


# Apply the same positional filter to the three parallel lists.
new_tags, new_rlocs, new_counts = map(_without_removed, (tags, rlocs, counts))

In [202]:
# Display the positions selected for removal next to the full tag list.
remove_idxs, tags 

({1}, ['ball', 'baseball', 'player', 'game', 'umpire', 'catcher', 'batter'])

In [203]:
# Display the tags that remain after dropping the merged positions.
new_tags

['ball', 'player', 'game', 'umpire', 'catcher', 'batter']