In [1]:
"""Heuristic wordnet baseline"""
###

'Heuristic wordnet baseline'

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
k_json_folder = '../puzzles/'

In [11]:
from decrypt.scrape_parse import (
    load_guardian_splits,
    load_guardian_splits_disjoint_hash
)

import random
from typing import *

import jellyfish

from multiset import Multiset
from nltk.corpus import wordnet as wn
from tqdm import tqdm

from decrypt.common.puzzle_clue import GuardianClue
from decrypt.common.util_wordnet import all_inflect
from decrypt.common import validation_tools as vt

In [8]:
# Wordnet functions to produce reverse dictionary sets

def normalize(lemma):
    """Return lemma with every underscore and hyphen replaced by a single space.

    WordNet lemma names use underscores and hyphens as separators inside
    multi-word entries. NOTE: the original author flags that this replacement
    possibly does not interact well with lemminflect.
    """
    separators_to_space = str.maketrans("_-", "  ")
    return lemma.translate(separators_to_space)

def get_syns(w: str) -> Set[str]:
    """
    Collect the normalized lemma names of every WordNet synset of w.

    :param w: word to look up
    :return: set of normalized lemma names (empty if w has no synsets)
    """
    return {
        normalize(name)
        for synset in wn.synsets(w)
        for name in synset.lemma_names()
    }

def get_syns_hypo1(w: str) -> Set[str]:
    """
    Collect normalized lemma names of every synset of w and of each synset's
    direct (depth-1) hyponyms.

    :param w: word to look up
    :return: set of normalized lemma names
    """
    found = set()
    for synset in wn.synsets(w):
        # The synset itself comes first, followed by its immediate hyponyms.
        for member in [synset, *synset.hyponyms()]:
            found.update(normalize(name) for name in member.lemma_names())
    return found

def get_syns_hypo_all(w: str, include_hyper=False, depth=3) -> Set[str]:
    """
    Collect normalized lemma names of every synset of w, of its hyponyms up to
    `depth` levels down and, when include_hyper is True, of its hypernyms up
    to `depth` levels up.

    :param w: word to look up
    :param include_hyper: whether to also traverse hypernyms
    :param depth: maximum traversal depth for the hyponym / hypernym closure
    :return: set of normalized lemma names
    """
    collected = set()

    def absorb(synset):
        # Add every lemma name of one synset, normalized, to the result.
        collected.update(normalize(name) for name in synset.lemma_names())

    for synset in wn.synsets(w):
        absorb(synset)
        if include_hyper:
            for ancestor in synset.closure(lambda s: s.hypernyms(), depth=depth):
                absorb(ancestor)
        for descendant in synset.closure(lambda s: s.hyponyms(), depth=depth):
            absorb(descendant)
    return collected

def get_first_and_last_word(c: GuardianClue):
    """Return (first, last) of the clue text split on single spaces.

    A one-word clue yields the same word twice. Splitting is on a literal
    space, so punctuation stays attached to the word it touches.
    """
    tokens = c.clue.split(" ")
    first, last = tokens[0], tokens[-1]
    return first, last


In [9]:
def pct_sim(str1, str2):
    """Return a length-normalized Levenshtein similarity.

    Computed as 1 - levenshtein(str1, str2) / max(len(str1), len(str2)):
    1.0 means identical strings, lower values mean more edits relative to the
    longer string.

    :param str1: first string
    :param str2: second string
    :return: similarity as a float
    """
    max_len = max(len(str1), len(str2))
    if max_len == 0:
        # Both strings are empty: they are identical, and the division below
        # would otherwise raise ZeroDivisionError.
        return 1.0
    lev = jellyfish.levenshtein_distance(str1, str2)
    return 1.0 - lev / max_len

def eval_wn(val_set: List[GuardianClue],
            fcn: Callable,
            do_fuzzy: bool,
            do_rank: bool = False,
            **fcn_kwargs):
    """
    Evaluate a WordNet lookup function as a clue-answering baseline.

    For each clue, candidates are gathered by calling fcn on the lowercased
    first and last words of the clue, optionally expanded with all_inflect,
    passed through vt.filter_to_len together with the solution, ordered, and
    scored with vt.eval.

    :param val_set: clues to evaluate
    :param fcn: lookup called as fcn(word, **fcn_kwargs); must return an
        iterable of candidate strings
    :param do_fuzzy: if True, also add all_inflect(candidate, None) for every
        candidate returned by fcn
    :param do_rank: if True, order candidates by character-multiset overlap
        with the clue text (largest first); otherwise shuffle them with a
        fixed-seed RNG
    :param fcn_kwargs: extra keyword arguments forwarded to fcn
    :return: list with one vt.ModelPrediction per clue, each carrying its
        vt.eval result in model_eval
    """
    shuffler = random.Random(42)

    predictions = []
    for clue in tqdm(val_set):
        candidates = set()

        # Direct lookups for the first and the last word of the clue.
        for word in get_first_and_last_word(clue):
            for candidate in fcn(word.lower(), **fcn_kwargs):
                candidates.add(candidate)

        # Optionally widen the pool with inflected forms of each lookup result.
        if do_fuzzy:
            for base in candidates.copy():
                candidates.update(all_inflect(base, None))

        _, length_matched = vt.filter_to_len(clue.soln_with_spaces, candidates)
        # First element of each entry; per the original note this is the
        # "with spaces" form of the candidate.
        answers = [entry[0] for entry in length_matched]

        # NOTE: an alternative ranking by pct_sim(candidate, clue)
        # (Levenshtein-based) was tried previously and is disabled.
        if do_rank:
            # Simple character overlap between the clue and each candidate.
            clue_chars = Multiset(clue.clue)
            answers = sorted(
                answers,
                key=lambda ans: len(clue_chars.intersection(Multiset(ans))),
                reverse=True)
        else:
            shuffler.shuffle(answers)

        prediction = vt.ModelPrediction(
            idx=clue.idx,
            input=clue.clue_with_lengths(),
            target=clue.soln_with_spaces,
            greedy="",
            sampled=answers)
        prediction.model_eval = vt.eval(prediction)
        predictions.append(prediction)

    return predictions





In [12]:
#################
# this is the primary baseline
######################

# naive set
def run_primary_wn_naive():
    """Run the primary baseline (synonyms + depth-1 hyponyms, no inflection,
    ranked by character overlap) on the validation split and then the test
    split returned by load_guardian_splits, aggregating each with
    vt.all_aggregate.
    """
    _, _, (_, val_split, test_split) = load_guardian_splits(k_json_folder)
    for split in (val_split, test_split):
        results = eval_wn(split, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True)
        vt.all_aggregate(results, label='syns,hypo1; no fuzzy, ranked by char overlap')

run_primary_wn_naive()


100%|██████████| 5518/5518 [00:10<00:00, 506.11it/s]
100%|██████████| 143991/143991 [00:00<00:00, 728162.12it/s]
  0%|          | 0/55783 [00:00<?, ?it/s]

[("length punct: '", 1),
 ('invalid: clue group', 7687),
 ('invalid: invalid start char (most are continuation clues)', 607),
 ('invalid: number in clue (commonly references another clue)', 7066),
 ('invalid: regexp', 75),
 ('invalid: soln length does not match specified lens (multi box soln)', 56),
 ('invalid: unrecognized char in clue (e.g. html)', 85),
 ('invalid: zero-len clue text after regexp', 15),
 ('length punct: ,', 24644),
 ('length punct: -', 4148),
 ('length punct: .', 8),
 ('length punct: /', 1),
 ('stat: parsed_puzzle', 5518),
 ('stat: total_clues', 143991),
 (1, 119956),
 (2, 20272),
 (3, 2957),
 (4, 686),
 (5, 112),
 (6, 8)]
Total clues: len(puzz_list)


100%|██████████| 55783/55783 [00:02<00:00, 20214.13it/s]


removed 1611 exact dupes
142380


100%|██████████| 28476/28476 [00:37<00:00, 761.80it/s] 
  0%|          | 92/28476 [00:00<00:31, 913.12it/s]

[('agg_filter_len_pre_truncate', 10.313948588284871),
 ('agg_filtered_few', 0.6430327293159152),
 ('agg_generate_few', 0.6430327293159152),
 ('agg_generate_none', 0.12824834948728753),
 ('agg_in_filtered', 0.13119820199466217),
 ('agg_in_sample', 0.1078803202697008),
 ('agg_sample_len', 5.7325467059980335),
 ('agg_sample_len_correct', 1.0),
 ('agg_sample_len_pre_truncate', 10.313948588284871),
 ('agg_sample_wordct_correct', 0.7355305072286205),
 ('agg_top_10_after_filter', 0.1078803202697008),
 ('agg_top_match', 0.028761061946902654),
 ('agg_top_match_len_correct', 0.8717516505127124),
 ('agg_top_match_none', 0.12824834948728753),
 ('agg_top_match_wordct_correct', 0.554291333052395),
 ('agg_top_sample_result_len_correct', 0.8717516505127124),
 ('agg_top_sample_result_wordct_correct', 0.554291333052395),
 ('filter_len_pre_truncate', 293700),
 ('filtered_few', 18311),
 ('generate_few', 18311),
 ('generate_none', 3652),
 ('in_filtered', 3736),
 ('in_sample', 3072),
 ('sample_len', 163240)

100%|██████████| 28476/28476 [00:40<00:00, 710.85it/s] 


[('agg_filter_len_pre_truncate', 10.353912066301447),
 ('agg_filtered_few', 0.6434541368169687),
 ('agg_generate_few', 0.6434541368169687),
 ('agg_generate_none', 0.12891557803062229),
 ('agg_in_filtered', 0.13260289366484057),
 ('agg_in_sample', 0.10756426464391067),
 ('agg_sample_len', 5.748911363955612),
 ('agg_sample_len_correct', 1.0),
 ('agg_sample_len_pre_truncate', 10.353912066301447),
 ('agg_sample_wordct_correct', 0.7319035343848118),
 ('agg_top_10_after_filter', 0.10756426464391067),
 ('agg_top_match', 0.025705857564264644),
 ('agg_top_match_len_correct', 0.8710844219693777),
 ('agg_top_match_none', 0.12891557803062229),
 ('agg_top_match_wordct_correct', 0.5543264503441495),
 ('agg_top_sample_result_len_correct', 0.8710844219693777),
 ('agg_top_sample_result_wordct_correct', 0.5543264503441495),
 ('filter_len_pre_truncate', 294838),
 ('filtered_few', 18323),
 ('generate_few', 18323),
 ('generate_none', 3671),
 ('in_filtered', 3776),
 ('in_sample', 3063),
 ('sample_len', 1637

In [None]:
# disjoint set
# soln_to_clue_map, all_clues, (train, val, test) = load_guardian_splits_disjoint_hash(DataDirs.Guardian.json_folder)

# disjoint set: same baseline as the naive run, evaluated on the disjoint-hash splits
def run_primary_wn_disj2():
    """Run the primary baseline (synonyms + depth-1 hyponyms, no inflection,
    ranked by character overlap) on the validation split and then the test
    split returned by load_guardian_splits_disjoint_hash, aggregating each
    with vt.all_aggregate.
    """
    _, _, (_, val_split, test_split) = load_guardian_splits_disjoint_hash(k_json_folder)
    for split in (val_split, test_split):
        results = eval_wn(split, fcn=get_syns_hypo1, do_fuzzy=False, do_rank=True)
        vt.all_aggregate(results, label='syns,hypo1; no fuzzy, ranked by char overlap')

run_primary_wn_disj2()
