# IMPORTS

In [None]:
%load_ext cython

In [None]:
import cfg

import cPickle
import gzip
import os
import pandas as pd
import pattern.en
import time
from collections import OrderedDict

import progress_bar as pb
from documents_utils import doc_generator_from_file
from normalize_text import normalize_text, get_stopword_set

from pattern_matching.pattern_matcher import PyPatternMatcher, PyPatternMatches
from pattern_matching.segmenter import PySegmenter

# LOAD ALL THE DOCUMENTS, QUERIES AND ASSOCIATIONS IN MEMORY

In [None]:
%%time
# Load the training queries, the query -> documents associations and the raw
# document texts. Pickles are binary data, so every file is opened in "rb"
# mode, and `with` closes each handle as soon as it has been read.
# NOTE: cPickle.load can execute arbitrary code; only load trusted files.
with open(cfg.raw_dir + "training/query_id_to_query.pickle", "rb") as infile:
    query_id_to_query = cPickle.load(infile)
with open(cfg.raw_dir + "training/query_id_to_doc_id_list.pickle", "rb") as infile:
    query_id_to_doc_id_list = cPickle.load(infile)
with open(cfg.raw_dir + "training/doc_id_to_raw_text.pickle", "rb") as infile:
    doc_id_to_raw_text = cPickle.load(infile)

# sanity checks: one document list per query, and every referenced document was loaded.
# The check uses doc_id_to_raw_text because the normalized texts are not built yet at this point.
assert len(query_id_to_query) == len(query_id_to_doc_id_list)
assert all(doc_id in doc_id_to_raw_text for doc_id_list in query_id_to_doc_id_list.itervalues() for doc_id in doc_id_list)

In [None]:
%%time
# Normalized text of every document, keyed by document id
doc_id_to_text = {
    doc_id: normalize_text(raw_text)
    for doc_id, raw_text in doc_id_to_raw_text.iteritems()
}

In [None]:
# Invert the query_id_to_doc_id_list associations:
# each document id is mapped to the list of query ids it is associated with
doc_id_to_query_id_list = {}
for query_id, doc_id_list in query_id_to_doc_id_list.iteritems():
    for doc_id in doc_id_list:
        doc_id_to_query_id_list.setdefault(doc_id, []).append(query_id)

## Definitions

In [None]:
%%time
# Terms accepted as expansions: one (whitespace-stripped) term per line of the
# gzipped file. The `with` block closes the gzip handle once the set is built.
with gzip.open(cfg.raw_dir + "frequent_terms.txt.gz") as terms_file:
    good_terms = set(line.strip() for line in terms_file)
print(len(good_terms))

In [None]:
def is_good_expansion(expansion):
    """Tell whether ``expansion`` is made only of terms found in ``good_terms``.

    An expansion without spaces is looked up as a whole; an expansion
    containing a space is accepted only if every whitespace-separated token
    is in ``good_terms``.
    """
    if " " not in expansion:
        return expansion in good_terms

    for token in expansion.split():
        if token not in good_terms:
            return False
    return True

In [None]:
%%cython

import pattern.en

def term_to_lemma(term, pos):
    """Return the lemma of ``term`` as a stripped ``str``.

    A term containing a space is lemmatized token by token and the lemmas are
    re-joined with single spaces. ``pos`` is currently unused.
    """
    if " " in term:
        lemma = ' '.join(pattern.en.lemma(t) or t for t in term.split())
    else:
        # same fallback as the multi-word branch: when no lemma is returned,
        # keep the term itself instead of producing the string "None"
        lemma = pattern.en.lemma(term) or term
    return str(lemma).strip()

def term_to_plural(term, pos):
    """Return the plural of ``term`` for the given ``pos``, as a stripped ``str``."""
    plural = pattern.en.pluralize(term, pos)
    plural = str(plural)
    return plural.strip()

In [None]:
%%cython

def query_match(or_query, text):
    """Tell whether the expanded query ``or_query`` matches ``text``.

    ``or_query`` is a sequence of AND-queries; each AND-query is a sequence of
    synsets; each synset is a sequence of items whose first element is a word.
    The query matches when at least one non-empty AND-query has, for every
    synset, at least one word occurring in ``text`` surrounded by spaces.
    """
    for and_query in or_query:
        # empty AND-queries are ignored (they would otherwise match everything)
        if len(and_query) == 0:
            continue

        for synset in and_query:
            for word_and_tags in synset:
                if (" " + word_and_tags[0] + " ") in text:
                    # this synset is satisfied, move on to the next one
                    break
            else:
                # no word of this synset is in the text: the AND-query fails
                break
        else:
            # every synset of this AND-query matched
            return True

    return False

# REWRITING STRATEGIES SUPPORT

In [None]:
# Collection-dependent term statistics: this dictionary depends on the dataset.
# The pickle is opened in binary mode and the handle is closed after loading.
with open(cfg.processed_dir + "term_to_df.pickle", "rb") as infile:
    term_to_df = cPickle.load(infile)

In [None]:
%time segments_thesaurus = frozenset(line[:-1] for line in open(cfg.thesaurus_dir + "thesaurus.dict"))

In [None]:
# Stop words provided by normalize_text.get_stopword_set(); remove_stopwords() drops query tokens found here
stopwords = get_stopword_set()

In [None]:
%time expansion_support = cPickle.load(open(thesaurus_dir + "expansion_support.pickle", "rb"))

%time expansion_support["segment_to_segment_id"] = dict((segment, segment_id) for segment_id, segment in enumerate(expansion_support["segment_id_to_segment"]))
assert len(expansion_support["segment_to_segment_id"]) == len(expansion_support["segment_id_to_segment"])

In [None]:
%%time
# For each supported pos tag: lemma -> set of ids of the segments having that lemma
pos_to_lemma_to_segment_id_set = dict((pos, {}) for pos in ('adj', 'adv', 'noun', 'verb'))

for segment_id, meaning_id_list in expansion_support['segment_id_to_meaning_id_list'].iteritems():
    term = expansion_support['segment_id_to_segment'][segment_id]
    # the pos tags of the segment are the pos tags of its meanings
    for pos in set(expansion_support['meaning_id_to_pos_segment_id_list'][meaning_id][0] for meaning_id in meaning_id_list):
        # only the pos tags listed above are indexed
        if pos in pos_to_lemma_to_segment_id_set:
            lemma = term_to_lemma(term, pos)
            pos_to_lemma_to_segment_id_set[pos].setdefault(lemma, set()).add(segment_id)

In [None]:
%%time
# Multi-word segment with its spaces removed -> tuple of ids of the segments collapsing to it
collapsed_segment_to_segment_id_list = {}

for segment_id in xrange(max(len(expansion_support['segment_id_to_entity_id_tags_list']), 1)):
    segment = expansion_support['segment_id_to_segment'][segment_id]
    # single-word segments have nothing to collapse
    if " " not in segment:
        continue

    new_segment = segment.replace(" ", "")
    # collapsed forms that are segments of their own are left alone
    if new_segment in expansion_support['segment_to_segment_id']:
        continue

    collapsed_segment_to_segment_id_list[new_segment] = \
        collapsed_segment_to_segment_id_list.get(new_segment, ()) + (segment_id,)

In [None]:
%%cython

def group_or_terms(or_term):
    """Merge the entries of ``or_term`` that share the same term.

    ``or_term`` is an iterable of ``(term, tags)`` pairs. The result is a list
    with one ``(term, tags)`` pair per distinct term: the tags of the first
    occurrence, followed by the tags of later occurrences that were not
    already present. The order of the pairs in the result is unspecified.
    """
    term_to_tags = dict()
    for term, tags in or_term:
        if term not in term_to_tags:
            term_to_tags[term] = tags
        else:
            known_tags = term_to_tags[term]
            term_to_tags[term] += tuple(tag for tag in tags if tag not in known_tags)

    # items() (unlike the Python-2-only iteritems()) exists on every Python
    # version; list() guarantees a list is returned in all cases
    return list(term_to_tags.items())

In [None]:
def get_source_term(term):
    """Wrap an original query term into a one-element tuple (the term, with no tags)."""
    source_entry = (term,)
    return source_entry

In [None]:
def filter_expansions(term_tags_list, query_terms):
    """Drop the expansions whose term is already one of ``query_terms``.

    ``term_tags_list`` holds items whose first element is the expansion term;
    the surviving items are returned in their original order.
    """
    kept = []
    for term_tags in term_tags_list:
        for query_term in query_terms:
            if term_tags[0] == query_term:
                # the expansion duplicates a query term: discard it
                break
        else:
            kept.append(term_tags)
    return kept

In [None]:
def remove_stopwords(query, query_segmenter=None):
    """Return ``query`` without the tokens found in ``stopwords``.

    The query is tokenized with ``query_segmenter.segment`` when a segmenter
    is given, otherwise by splitting on whitespace. The kept tokens are
    re-joined with single spaces (the segmentation is discarded).
    If every token is a stop word, the query is returned unchanged.
    """
    if query_segmenter:
        tokens = query_segmenter.segment(query)
    else:
        tokens = query.split()

    kept_tokens = [token for token in tokens if token not in stopwords]

    # a query made only of stop words is kept as it is
    if not kept_tokens:
        return query

    return " ".join(kept_tokens)

In [None]:
def get_thesaurus_expansions(term):
    """Return the thesaurus-based expansions of ``term`` as ``(expansion, tags)`` pairs.

    Steps, for each pos tag indexed in ``pos_to_lemma_to_segment_id_set``:
      1. lemmatize ``term``;
      2. look up the segments registered under that lemma ("normalized" terms);
      3. collect the synonyms of those segments, i.e. the segments sharing a
         meaning with the same pos tag;
      4. pluralize the normalized terms and synonyms found for the "noun" tag.

    The candidates are tagged ``(pos, "Lem")``, ``(pos, "Norm")``,
    ``(pos, "Syn")`` or ``("noun", "Plu")`` and merged with ``group_or_terms``.
    Candidates that contain ``term`` as a space-delimited substring, or that
    are rejected by ``is_good_expansion``, are dropped from the result.
    """
    pos_set = pos_to_lemma_to_segment_id_set.keys()

    # get the LEMMA for each possible pos tag
    pos_to_lemma = dict(
        (pos, term_to_lemma(term, pos))
        for pos in pos_set
    )
    # filtering unlikely lemmas
    # (disabled: the `if False` guard means this branch never runs)
    if False:
        for pos in pos_set:
            if pos_to_lemma[pos] not in pos_to_lemma_to_segment_id_set[pos]:
                del pos_to_lemma[pos]

    # use the lemma only if this term doesn't appear in our segments
    # (the leading `False and` disables this branch: the `else` below always runs)
    if False and term in expansion_support['segment_to_segment_id']:
        segment_id = expansion_support['segment_to_segment_id'][term]
        if segment_id in expansion_support['segment_id_to_meaning_id_list']:
            meaning_pos_set = set(
                expansion_support['meaning_id_to_pos_segment_id_list'][meaning_id][0]
                for meaning_id in expansion_support['segment_id_to_meaning_id_list'][segment_id]
            )
        else:
            meaning_pos_set = set()

        segment_id_set = set([segment_id])
        pos_to_normalized_segment_id_set = dict(
            (pos, segment_id_set if pos in meaning_pos_set else set())
            for pos in pos_set
        )
    else:
        # find possible NORMALIZED versions of the lemma for each pos tag
        # (an empty set when the lemma is unknown for that pos tag)
        pos_to_normalized_segment_id_set = dict(
            #(pos, pos_to_lemma_to_segment_id_set[pos][lemma])
            (pos, pos_to_lemma_to_segment_id_set[pos][lemma] if lemma in pos_to_lemma_to_segment_id_set[pos] else set())
            for (pos, lemma) in pos_to_lemma.iteritems()
        )

    # translate the normalized segment ids into their segment strings
    pos_to_normalized_term_set = dict(
        (pos, set(expansion_support['segment_id_to_segment'][segment_id] for segment_id in normalized_segment_id_list))
        for (pos, normalized_segment_id_list) in pos_to_normalized_segment_id_set.iteritems()
    )

    # get the SYNONYMS of each normalized version
    pos_to_synset = dict()
    for pos, normalized_segment_id_list in pos_to_normalized_segment_id_set.iteritems():
        pos_to_synset[pos] = set(
            expansion_support['segment_id_to_segment'][segment_id]
            for normalized_segment_id in normalized_segment_id_list
            # NOTE(review): the membership test two lines below is evaluated after this
            # lookup, so it cannot protect the lookup from a missing key
            for meaning_id in expansion_support['segment_id_to_meaning_id_list'][normalized_segment_id]
                if normalized_segment_id in expansion_support['segment_id_to_meaning_id_list']
                and pos == expansion_support['meaning_id_to_pos_segment_id_list'][meaning_id][0]
            for segment_id in expansion_support['meaning_id_to_pos_segment_id_list'][meaning_id][1]
                if (" " + expansion_support['segment_id_to_segment'][normalized_segment_id] + " ") not in (" " + expansion_support['segment_id_to_segment'][segment_id] + " ")  # discard synonyms that extend the starting term with additional terms
        )

    # get the PLURALS of the normalized terms and their synonyms (which should be in the singular form)
    terms_to_pluralize = set()
    if "noun" in pos_to_normalized_term_set:
        terms_to_pluralize.update(pos_to_normalized_term_set["noun"])
    if "noun" in pos_to_synset:
        terms_to_pluralize.update(pos_to_synset["noun"])

    noun_plurals = set(
        term_to_plural(new_term, "noun")
        for new_term in terms_to_pluralize
    )

    # put everything together: group_or_terms merges the tags of candidates produced by more than one step
    res = group_or_terms(
        [
            (lemma, (pos, "Lem"))
            for (pos, lemma) in pos_to_lemma.iteritems()
        ] + [
            (normalized_term, (pos, "Norm"))
            for (pos, normalized_terms_set) in pos_to_normalized_term_set.iteritems()
            for normalized_term in normalized_terms_set
        ] + [
            (synonym, (pos, "Syn"))
            for (pos, synonyms_set) in pos_to_synset.iteritems()
            for synonym in synonyms_set
        ] + [
            (noun_plural, ("noun", "Plu"))
            for noun_plural in noun_plurals
        ]
    )

    return [
        (synonym, tags)
        for (synonym, tags) in res
        if (" " + term + " ") not in (" " + synonym + " ")  # remove synonyms that contain the original term
            and is_good_expansion(synonym)
    ]

In [None]:
def _get_entity_expansions(segment_id):
    """Return the entity-based expansions of the segment with id ``segment_id``.

    For every entity linked to the segment, every segment of that entity is a
    candidate, tagged with the entity tags followed by the link tags.
    Candidates containing the source segment as a space-delimited substring
    are dropped. Segment ids beyond the entity table yield an empty list.
    """
    entity_links = expansion_support["segment_id_to_entity_id_tags_list"]
    if segment_id >= len(entity_links):
        return []

    entities = expansion_support["entity_id_to_tags_segment_id_list"]
    segments = expansion_support["segment_id_to_segment"]

    candidates = []
    for entity_id, tags in entity_links[segment_id]:
        for new_segment_id in entities[entity_id][1]:
            candidates.append((segments[new_segment_id], entities[entity_id][0] + tags))

    # remove synonyms that contain the original term
    padded_source = " " + segments[segment_id] + " "
    expansions = []
    for segment, tags in candidates:
        if padded_source not in (" " + segment + " "):
            expansions.append((segment, tags))
    return expansions

def get_entity_expansions(segment):
    """Return the entity-based expansions of ``segment``.

    A known segment is expanded directly. An unknown single-word segment that
    is the collapsed (spaces removed) form of known segments gets the
    concatenated expansions of those segments. Anything else yields ``[]``.
    """
    segment_id = expansion_support["segment_to_segment_id"].get(segment)
    if segment_id is not None:
        return _get_entity_expansions(segment_id)

    if " " in segment or segment not in collapsed_segment_to_segment_id_list:
        return []

    # TEMP CODE
    expansions = []
    for new_segment_id in collapsed_segment_to_segment_id_list[segment]:
        expansions.extend(_get_entity_expansions(new_segment_id))
    return expansions

# REWRITING STRATEGIES

In [None]:
# builds the query representation without any expansion
def query_to_base(query):
    """Return the un-expanded representation of ``query``.

    The query is normalized, stripped of stop words (no segmenter is used) and
    split on whitespace. Each term becomes a "synset" holding only the term
    itself, and the result is a list with that single segmentation.
    """
    normalized_query = normalize_text(query)
    cleaned_query = remove_stopwords(normalized_query, query_segmenter=None)

    # one synset per term, composed only of the term itself, to match the
    # signature of the expanded queries
    and_query = [[get_source_term(term)] for term in cleaned_query.split()]

    # the expanded query is composed only of this segmentation
    return [and_query]

In [None]:
# expand using the thesaurus and the entities, after segmenting the query
def get_query_to_segmented_thesaurus_expansion(min_segmentation_freq):
    """Build a rewriting function that segments a query and expands each segment.

    A ``PySegmenter`` is created once from the double-quoted keys of
    ``term_to_df`` whose unquoted text is in ``segments_thesaurus``, together
    with ``term_to_df``, ``-1.0`` and ``min_segmentation_freq``.

    The returned function maps a raw query to a list holding one
    segmentation; each segment becomes a synset made of the segment itself
    followed by its thesaurus and entity expansions, minus the expansions
    equal to another segment of the query.
    """
    known_segments = set()
    for segment in term_to_df:
        if segment[0] == segment[-1] == "\"":
            unquoted = segment[1:-1]
            if unquoted in segments_thesaurus:
                known_segments.add(unquoted)

    full_segmenter = PySegmenter(known_segments, term_to_df, -1.0, min_segmentation_freq)

    def _query_to_segm_ent_exp(query):
        # normalize the text and remove the stop words, depending on whether
        # they belong to some segment or not
        cleaned_query = remove_stopwords(normalize_text(query), query_segmenter=full_segmenter)

        # segment using entities and thesaurus words (the order is important)
        query_terms = full_segmenter.segment(cleaned_query)

        # create one synset per segment
        and_query = []
        for t in query_terms:
            candidates = group_or_terms(get_thesaurus_expansions(t) + get_entity_expansions(t))
            and_query.append([get_source_term(t)] + filter_expansions(candidates, query_terms))

        # the expanded query is composed only of this segmentation
        return [and_query]

    return _query_to_segm_ent_exp

In [None]:
# Rewriting strategies to evaluate, in evaluation order: name -> function turning a raw query into an expanded query
strategies = OrderedDict()
strategies["Base"] = query_to_base
strategies["SegmentedThesaurusExpansion(100)"] = get_query_to_segmented_thesaurus_expansion(100)

In [None]:
%%time
# Apply every rewriting strategy to every training query.
# `keys` / `table` collect, per strategy, its name and the average wall-clock
# time (in seconds) spent expanding one query.
all_strategy_name_to_query_id_to_query = OrderedDict()
keys = []
table = []
for strategy_name, strategy in strategies.iteritems():
    all_strategy_name_to_query_id_to_query[strategy_name] = dict()

    start_time = time.time()
    # iterate over the query ids, i.e. the keys of query_id_to_query
    # (`num` is only the enumeration index)
    for num, query_id in pb.iter_progress(enumerate(query_id_to_query), size=len(query_id_to_query), labeling_fun={"prefix":strategy_name}, hide_bar_on_success=True):
        all_strategy_name_to_query_id_to_query[strategy_name][query_id] = strategy(query_id_to_query[query_id])
    keys.append(strategy_name)
    table.append([1.0 * (time.time()-start_time) / len(query_id_to_query)])
# it takes about 7min 30s

In [None]:
# Display one row per strategy (labels from `keys`) with its average per-query expansion time (values from `table`)
pd.DataFrame(table, index=keys, columns=["Avg. expansions time"])

# COMPUTE THE NUMBER OF MATCHES OF THE EXPANDED QUERIES

## Compute the number of matches of each rewrite

In [None]:
%%time
# For each strategy: query id -> number of associated documents matched by the rewritten query
strategy_name_to_query_id_to_num_match = OrderedDict()

for strategy_name in strategies:
    strategy_name_to_query_id_to_num_match[strategy_name] = dict.fromkeys(query_id_to_doc_id_list, 0)

for doc_id, doc_text in pb.iter_progress(doc_id_to_text.iteritems(), size=len(doc_id_to_text)):
    # documents that are not associated with any query contribute nothing
    for query_id in doc_id_to_query_id_list.get(doc_id, ()):
        # for each strategy check if the rewritten query matches the document
        for strategy_name, _query_id_to_query in all_strategy_name_to_query_id_to_query.iteritems():
            # query_match returns a bool, which counts as 0 or 1 here
            strategy_name_to_query_id_to_num_match[strategy_name][query_id] += query_match(_query_id_to_query[query_id], doc_text)

# GROUND TRUTH BUILD (using the same format used previously)

In [None]:
# Name of the strategy whose expanded queries are used to build the ground truth below
strategy_name = "SegmentedThesaurusExpansion(100)"

In [None]:
# fail early if the selected strategy is not one of the evaluated strategies
assert strategy_name in strategies

In [None]:
# For the training of the models we consider only the queries having at least
# one candidate expansion that can improve their recall, i.e. the queries whose
# expanded form matches more documents than their base form.
# The query ids are taken from the match counters of the selected strategy.
queries_with_recall_improvement = [
    query_id
    for query_id in strategy_name_to_query_id_to_num_match[strategy_name]
    if strategy_name_to_query_id_to_num_match[strategy_name][query_id] > strategy_name_to_query_id_to_num_match["Base"][query_id]
]

## COMPUTE THE WORD OCCURRENCES OF EACH QUERY, NEEDED BY THE TRAINING

In [None]:
%%cython

def compute_word_occurrence_set(expanded_query, doc_id_list, doc_id_to_text):
    """Map every word of ``expanded_query`` to the documents containing it.

    ``expanded_query`` is a sequence of AND-queries, each a sequence of
    synsets, each a sequence of items whose first element is a word.
    Returns a dict: word -> set of the ids in ``doc_id_list`` whose text (from
    ``doc_id_to_text``) contains the word surrounded by spaces.
    """
    word_to_doc_ids = {}
    for and_query in expanded_query:
        for synset in and_query:
            for word_and_tags in synset:
                word = word_and_tags[0]
                # each distinct word is scanned only once
                if word in word_to_doc_ids:
                    continue

                padded_word = " " + word + " "
                matching_doc_ids = set()
                for doc_id in doc_id_list:
                    if padded_word in doc_id_to_text[doc_id]:
                        matching_doc_ids.add(doc_id)
                word_to_doc_ids[word] = matching_doc_ids

    return word_to_doc_ids

In [None]:
%%time
# query id -> (word -> set of the query's documents containing the word),
# computed on the queries expanded with the selected strategy
query_id_to_word_to_occurrence_set = {}

for query_id in pb.iter_progress(query_id_to_query):
    # negative query ids are skipped
    if query_id >= 0:
        query_id_to_word_to_occurrence_set[query_id] = compute_word_occurrence_set(
            all_strategy_name_to_query_id_to_query[strategy_name][query_id],
            query_id_to_doc_id_list[query_id],
            doc_id_to_text,
        )

In [None]:
%%time

# add the occurrences of the words of the "Base" (un-expanded) queries to the same dictionaries
for query_id in pb.iter_progress(query_id_to_query):
    # negative query ids are skipped
    if query_id >= 0:
        query_id_to_word_to_occurrence_set[query_id].update(
            compute_word_occurrence_set(
                all_strategy_name_to_query_id_to_query["Base"][query_id],
                query_id_to_doc_id_list[query_id],
                doc_id_to_text,
            )
        )

## SAVE THE GROUND TRUTH

In [None]:
# Create the output directory of the ground truth if it does not exist yet.
# The directory test lives in os.path (the os module has no isdir function).
if not os.path.isdir(cfg.processed_dir + "training/"):
    os.mkdir(cfg.processed_dir + "training/")

In [None]:
%%time
# save the ids of the queries whose expansion improves the recall
with open(
    cfg.processed_dir + "training/expanded_query.queries_with_recall_improvement.pickle",
    "wb",
) as outfile:
    cPickle.dump(queries_with_recall_improvement, outfile, cPickle.HIGHEST_PROTOCOL)

In [None]:
%%time
# save, for each query, the documents in which each of its words occurs
with open(
    cfg.processed_dir + "training/expanded_query.query_id_to_word_to_occurrence_set.pickle",
    "wb",
) as outfile:
    cPickle.dump(query_id_to_word_to_occurrence_set, outfile, cPickle.HIGHEST_PROTOCOL)

In [None]:
%%time
# save the queries expanded with the selected strategy (only the ids present in query_id_to_query)
with open(cfg.processed_dir + "training/query_id_to_expanded_query.pickle", "wb") as outfile:
    cPickle.dump(
        {
            query_id: expanded_query
            for query_id, expanded_query in all_strategy_name_to_query_id_to_query[strategy_name].iteritems()
            if query_id in query_id_to_query
        },
        outfile,
        cPickle.HIGHEST_PROTOCOL
    )

In [None]:
%%time
# save the un-expanded ("Base") queries (only the ids present in query_id_to_query)
with open(cfg.processed_dir + "training/query_id_to_base_query.pickle", "wb") as outfile:
    cPickle.dump(
        {
            query_id: expanded_query
            for query_id, expanded_query in all_strategy_name_to_query_id_to_query["Base"].iteritems()
            if query_id in query_id_to_query
        },
        outfile,
        cPickle.HIGHEST_PROTOCOL
    )