# PREAMBLE

In [None]:
%load_ext cython
%matplotlib inline

In [None]:
import cfg

import collections
import copy
import cPickle
import heapq
import math
import numpy as np
import json
import os
import pandas as pd
import random
from scipy import stats
import xgboost
from scipy.stats import rankdata

import progress_bar as pb
import feature_extraction as fe

import efficient_query_expansion.index_cache as index_cache
from collection_stats.collection_stats_restricted import PyCollectionStatsRestricted

In [None]:
%%time
# Reuse a previously dumped index cache when the dump file exists; otherwise
# create a fresh IndexCache.
idx_cache_file_path = cfg.tmp_dir + "index_cache_dump.bin"
if os.path.isfile(idx_cache_file_path):
    idx_cache = index_cache.IndexCache.load(idx_cache_file_path)
    print "Found {} entries in cache".format(len(idx_cache))
else:
    # No dump on disk: build a new cache with host 127.0.0.1 and port 9001
    # (presumably the address of the index service -- confirm).
    idx_cache = index_cache.IndexCache(host="127.0.0.1", port=9001)
    print "Cache file not found"

# UTILITIES

In [None]:
%%cython

cimport cython
import numpy as np
cimport numpy as np

cdef extern from "math.h":
    double sqrt(double m)


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
@cython.cdivision(True)
cdef set _sets_union(set_iterator):
    # Fold every iterable produced by `set_iterator` into one freshly created
    # set; the inputs themselves are never modified.
    cdef set merged = set()
    for members in set_iterator:
        merged.update(members)
    return merged


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
@cython.cdivision(True)
cdef set _sets_intersection(set_iterator):
    # Return a NEW set with the elements common to every set produced by
    # `set_iterator`.
    #
    # The first set is copied before being narrowed down: intersecting in
    # place on the caller's object would silently shrink it (e.g. an entry of
    # a word -> occurrences mapping) if it were ever passed in directly.
    # An empty iterator yields an empty set instead of None, so the result can
    # always be fed to `_sets_union` / `len`.
    cdef set common = None
    for members in set_iterator:
        if common is None:
            common = set(members)
        else:
            common.intersection_update(members)
    if common is None:
        return set()
    return common


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
@cython.cdivision(True)
def c_get_query_occurrences(query_repr, dict word_to_occurrence_set):
    """Return the set of occurrences matched by ``query_repr``.

    ``query_repr`` is iterated as an OR of AND-queries; each AND-query is an
    AND of synsets; each synset is an OR of term entries whose first item is
    the word used as key into ``word_to_occurrence_set``.
    """
    # One occurrence set per AND-query: intersect the per-synset unions.
    per_and_query = (
        _sets_intersection(
            _sets_union(
                word_to_occurrence_set[entry[0]]
                for entry in synset
            )
            for synset in and_query
        )
        for and_query in query_repr
    )
    # The whole query matches whatever any of its AND-queries matches.
    return _sets_union(per_and_query)


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.nonecheck(False)
@cython.cdivision(True)
def c_get_num_match(dict qid_to_query, qid_set, dict qid_to_word_to_occurrence_set):
    """Map each selected qid to the number of occurrences its query matches.

    When ``qid_set`` is None every key of ``qid_to_query`` is evaluated;
    otherwise only the (de-duplicated) ids in ``qid_set`` are.
    """
    selected = qid_to_query.keys() if qid_set is None else set(qid_set)

    counts = {}
    for qid in selected:
        counts[qid] = len(
            c_get_query_occurrences(qid_to_query[qid], qid_to_word_to_occurrence_set[qid])
        )
    return counts

In [None]:
def query_terms_iterator(query_repr, only_positions=None):
    """Iterate over the terms of a nested query representation.

    Yields ``(abs_pos, and_pos, syn_pos, term_pos, term_tags)`` tuples where
    ``abs_pos`` counts terms across the whole query, ``and_pos`` / ``syn_pos``
    / ``term_pos`` are the indices inside the three nesting levels, and
    ``term_tags`` is the stored term entry.

    If ``only_positions`` is given (a non-empty iterable of absolute
    positions) only those terms are yielded, in increasing position order.
    An ``Exception`` is raised once the query is exhausted while a requested
    position is still pending.
    """
    if only_positions is None:
        abs_pos = 0
        for and_pos, and_query in enumerate(query_repr):
            for syn_pos, synset in enumerate(and_query):
                for term_pos, term_tags in enumerate(synset):
                    yield (abs_pos, and_pos, syn_pos, term_pos, term_tags)
                    abs_pos += 1
        return

    assert hasattr(only_positions, "__iter__")
    wanted = sorted(only_positions)
    assert len(wanted) > 0
    num_wanted = len(wanted)
    cursor = 0  # index in `wanted` of the next position to emit

    last_abs = -1  # absolute position of the last term walked past
    for and_pos, and_query in enumerate(query_repr):
        for syn_pos, synset in enumerate(and_query):
            # Skip the whole synset when the next wanted position lies beyond it.
            if wanted[cursor] > last_abs + len(synset):
                last_abs += len(synset)
                continue
            for term_pos, term_tags in enumerate(synset):
                last_abs += 1
                if wanted[cursor] != last_abs:
                    continue
                yield (last_abs, and_pos, syn_pos, term_pos, term_tags)
                cursor += 1
                if cursor >= num_wanted:
                    return
    raise Exception("One of the positions ({}) is out of the query".format(wanted[cursor]))

In [None]:
def get_step_function(step, fun=None):
    """Build a function that is 1 up to ``step`` and decays afterwards.

    ``step`` must be a non-negative float. Without ``fun`` the result drops
    to 0.0 past ``step``; with ``fun`` (a callable satisfying ``fun(0) == 1``)
    the result past ``step`` is ``fun(v - step)``.
    """
    assert isinstance(step, float) and step >= 0
    assert fun is None or (hasattr(fun, "__call__") and fun(0) == 1)

    def flat_then_zero(v):
        if v <= step:
            return 1
        return 0.0

    def flat_then_decay(v):
        if v <= step:
            return 1
        return fun(v - step)

    return flat_then_zero if fun is None else flat_then_decay

def get_lin_decay_function(slope):
    """Return ``v -> 1 + slope * v`` for a non-positive float ``slope``."""
    assert isinstance(slope, float) and slope <= 0

    def linear_decay(v):
        return 1 + slope * v

    return linear_decay

def get_exp_decay_function(alpha):
    """Return ``v -> exp(alpha * v)`` for a non-positive float ``alpha``."""
    assert isinstance(alpha, float) and alpha <= 0

    def exponential_decay(v):
        return math.exp(alpha * v)

    return exponential_decay

def get_line_decay_slope_from_step(step, ratio_step):
    """Return the (negative) slope ``-1 / (step * ratio_step)``.

    Both arguments must be strictly positive floats.
    """
    assert isinstance(step, float) and step > 0
    assert isinstance(ratio_step, float) and ratio_step > 0
    decay_span = step * ratio_step
    return -1.0 / decay_span

def get_exp_decay_alpha_from_step(step, ratio_step):
    """Return the (negative) exponent rate ``-2 / (step * ratio_step)``.

    Both arguments must be strictly positive floats.
    """
    assert isinstance(step, float) and step > 0
    assert isinstance(ratio_step, float) and ratio_step > 0
    decay_span = step * ratio_step
    return -2.0 / decay_span

In [None]:
def harmonic_mean(v1, v2, v1_weight=0.5):
    """Weighted harmonic mean of ``v1`` and ``v2``.

    ``v1_weight`` is the weight of ``v1``; ``v2`` gets ``1 - v1_weight``.
    A zero denominator is replaced by 1.0, so the result is then ``v1 * v2``.
    """
    numerator = v1 * v2
    denominator = v1_weight * v2 + (1.0 - v1_weight) * v1
    if not denominator:
        denominator = 1.0
    return numerator / denominator

def get_perf_to_harmonic_mean_function(fun1, fun2, fun1_weight=0.5):
    """Combine two performance metrics through their weighted harmonic mean.

    Returns a function of ``perf`` computing
    ``harmonic_mean(fun1(perf), fun2(perf), fun1_weight)``; ``fun1_weight``
    must be a float in [0, 1].
    """
    assert isinstance(fun1_weight, float) and 0 <= fun1_weight <= 1

    def combined_metric(perf):
        return harmonic_mean(fun1(perf), fun2(perf), fun1_weight)

    return combined_metric

def perf_to_recall(perf):
    """Return ``perf.num_rel_ret / perf.num_rel`` as a float ratio."""
    retrieved = 1.0 * perf.num_rel_ret
    return retrieved / perf.num_rel

def get_perf_to_exe_time_function(fun):
    """Wrap ``fun`` so that it is applied to the ``exe_time`` attribute of a perf object."""
    def exe_time_metric(perf):
        return fun(perf.exe_time)

    return exe_time_metric

def get_perf_to_my_eet_function(step, step_ratio, effectivity_weight=0.5):
    """Build the metric mixing recall with an execution-time score.

    The execution-time score is 1 up to ``step`` and, past it, either 0
    (``step_ratio == 0``) or an exponential decay parameterised by ``step``
    and ``step_ratio``. Recall and the time score are merged with a weighted
    harmonic mean where recall has weight ``effectivity_weight``.
    """
    assert isinstance(step, float) and step > 0
    assert isinstance(step_ratio, float) and step_ratio >= 0

    if step_ratio == 0:
        decay = None
    else:
        decay = get_exp_decay_function(get_exp_decay_alpha_from_step(step, step_ratio))
    time_score = get_step_function(step, fun=decay)

    return get_perf_to_harmonic_mean_function(
        perf_to_recall,
        get_perf_to_exe_time_function(time_score),
        effectivity_weight
    )

# READ THE DATASET

In [None]:
%%time
def jsonConvertKeys(constructor):
    """Return a ``json`` ``object_hook`` that rebuilds each decoded object
    with its keys passed through ``constructor`` (values are left untouched).
    """
    def convert_keys(decoded):
        return dict((constructor(key), value) for key, value in decoded.items())

    return convert_keys

# Load the training queries and their relevant documents, turning the JSON
# string keys back into integer query ids. The files are opened through
# `with` so the handles are closed deterministically.
with open(cfg.training_dir + "qid_to_query.json", "r") as infile:
    qid_to_query = json.load(infile, object_hook=jsonConvertKeys(int))
with open(cfg.training_dir + "qid_to_docid_list.json", "r") as infile:
    qid_to_docid_list = json.load(infile, object_hook=jsonConvertKeys(int))

# both files must describe the same number of queries
assert len(qid_to_query) == len(qid_to_docid_list)

In [None]:
# Pickles are read in binary mode (as done for the collection statistics
# pickle further down) and the handles are closed as soon as loading is done.
with open(cfg.processed_dir + "qid_to_base_query.pickle", "rb") as infile:
    qid_to_base_query = cPickle.load(infile)
with open(cfg.processed_dir + "qid_to_candidates.pickle", "rb") as infile:
    qid_to_candidates = cPickle.load(infile)

# every query must have both a base form and a candidate expansion
assert set(qid_to_base_query.keys()) == set(qid_to_candidates.keys())

In [None]:
# Number of candidate terms per query: the query is the OR composition of
# different AND_QUERIES, each AND_QUERY is the AND composition of different
# SYNSETs, and every entry of every synset counts as one candidate.
qid_to_num_candidates = {
    qid: sum(len(synset) for and_query in candidates for synset in and_query)
    for qid, candidates in qid_to_candidates.iteritems()
}

In [None]:
%%time
# binary mode + `with`: pickle data is bytes and the handle must not leak
with open(cfg.processed_dir + "queries_with_recall_improvement.pickle", "rb") as infile:
    queries_with_recall_improvement = cPickle.load(infile)

In [None]:
%%time
# binary mode + `with`: pickle data is bytes and the handle must not leak
with open(cfg.processed_dir + "qid_to_word_to_occurrence_set.pickle", "rb") as infile:
    qid_to_word_to_occurrence_set = cPickle.load(infile)

In [None]:
def get_num_match(qid_to_query, qid_set=None):
    """Return ``{qid: number of matched occurrences}`` for the given queries.

    Thin wrapper around ``c_get_num_match`` that supplies the module-level
    ``qid_to_word_to_occurrence_set``; ``qid_set`` optionally restricts the
    evaluated query ids.
    """
    occurrences_by_qid = qid_to_word_to_occurrence_set
    return c_get_num_match(qid_to_query, qid_set, occurrences_by_qid)

def get_num_match_query(qid, query_repr):
    """Return how many occurrences ``query_repr`` matches for query ``qid``,
    using the module-level ``qid_to_word_to_occurrence_set``.
    """
    word_to_occurrence_set = qid_to_word_to_occurrence_set[qid]
    matched = c_get_query_occurrences(query_repr, word_to_occurrence_set)
    return len(matched)

In [None]:
# checks
def _iter_candidate_terms(candidates):
    """Lazily yield every term entry of an expansion representation."""
    for and_query in candidates:
        for synset in and_query:
            for term in synset:
                yield term


for qid, candidates in qid_to_candidates.iteritems():
    base_query = qid_to_base_query[qid]

    # check the synonyms: no expansion may be blank
    for syn, tag in _iter_candidate_terms(candidates):
        if syn.strip() == "":
            raise AssertionError("One of the expansions of the query {} is empty".format(qid))

    # check the tags: every term must carry a tags value
    for syn, tag in _iter_candidate_terms(candidates):
        if tag is None:
            raise AssertionError("One of the tags of the query {} is None".format(qid))

    # check the base query
    if len(base_query[0]) == 0:
        raise AssertionError("The query {} contains an empty base query".format(qid))

    # check the shape of the base query and of the expansions: same number of
    # AND-queries and, pairwise, the same number of synsets
    same_shape = len(candidates) == len(base_query) and all(
        len(candidate_and) == len(base_and)
        for candidate_and, base_and in zip(candidates, base_query)
    )
    if not same_shape:
        raise AssertionError("The query {} has two different shapes for the base_query and candidates".format(qid))

In [None]:
# Dataset summary: total queries, queries whose synsets all hold at most one
# entry (reported as "zero expansions"), and queries loaded from
# queries_with_recall_improvement.
print "Number of queries", len(qid_to_candidates)
print "Number of queries with zero expansions", sum(1 for qid, candidates in qid_to_candidates.iteritems()
                                                 if all(len(synset) <= 1 for and_query in candidates for synset in and_query))
print "Number of queries with improvements", len(queries_with_recall_improvement)

# READ COLLECTION STATISTICS

In [None]:
%%time
# Load the restricted collection statistics from the thesaurus directory.
collection_stats = PyCollectionStatsRestricted.load(cfg.thesaurus_dir + "collection_stats_restricted.bin")

In [None]:
# The pickle stores segment_id -> segment; invert it into segment -> segment_id.
collection_stats_segment_to_segment_id = dict(
    (segment, segment_id)
    for segment_id, segment in cPickle.load(open(cfg.thesaurus_dir + "collection_stats_restricted_segmentid_to_segment.pickle", "rb")).iteritems()
)

In [None]:
# Sizes reported by the statistics object (terms, term pairs, term triples).
print "NumTerms: {: >8}".format(collection_stats.get_num_terms())
print "CoOcc2:   {: >8}".format(collection_stats.get_num_term_pairs())
print "CoOcc3:   {: >8}".format(collection_stats.get_num_term_triples())

# FEATURES SUPPORT

In [None]:
def identity(np_array):
    """No-op normalization: return the input unchanged (same object)."""
    return np_array

def normalize_range(np_array):
    """Min-max normalize along axis 0, ignoring NaNs.

    Each column is mapped to ``(x - min) / (max - min)``. Columns whose range
    is zero are divided by 1 instead, so they come out as all zeros rather
    than NaN/inf.

    The divisor is built with ``np.where`` instead of an in-place float
    addition, so integer input works as well (in-place addition of a float
    array into an integer array is rejected or truncated by NumPy) and the
    division is always carried out in floating point.
    """
    min_array = np.nanmin(np_array, axis=0)
    value_range = np.nanmax(np_array, axis=0) - min_array
    divisor = np.where(value_range == 0, 1.0, value_range)
    return (np_array - min_array) / divisor

def normalize_max(np_array):
    """Divide every column (axis 0) by its NaN-ignoring maximum.

    Columns whose maximum is zero are divided by 1 instead, leaving them
    unchanged. No check is made that the data is non-negative.

    The divisor is built with ``np.where`` instead of an in-place float
    addition, so integer input works as well (in-place addition of a float
    array into an integer array is rejected or truncated by NumPy) and the
    division is always carried out in floating point.
    """
    column_max = np.nanmax(np_array, axis=0)
    divisor = np.where(column_max == 0, 1.0, column_max)
    return np_array / divisor

def normalize_rank(np_array):
    """Replace every value by its rank inside its own column.

    Ranks come from ``scipy.stats.rankdata`` with ``method="min"`` (ties
    share the lowest rank, ranks start at 1); the result is a float array
    with the same layout as the input.
    """
    by_column = np_array.T
    column_ranks = np.empty(by_column.shape)
    for column_index, column_values in enumerate(by_column):
        column_ranks[column_index] = rankdata(column_values, method="min")
    return column_ranks.T

def normalize_rank_descending(np_array):
    """Column-wise ranks where the largest value gets rank 1.

    Implemented by ranking the negated input with ``normalize_rank``.
    """
    negated = -np_array
    return normalize_rank(negated)

def standardize(np_array):
    """Z-score every column (axis 0), ignoring NaNs.

    Columns with zero standard deviation are divided by 1 instead, so they
    come out as all zeros.
    """
    spread = np.nanstd(np_array, axis=0)
    spread += (spread == 0.0) * 1.0
    centered = np_array - np.nanmean(np_array, axis=0)
    return centered / spread

In [None]:
%%time
tag_to_pos = dict(
    (tag, i)
    for i, tag in enumerate(sorted(set(
        tag
        for qid, exp_repr in qid_to_candidates.iteritems()
        for and_query in exp_repr
        for synset in and_query
        for term, tags in synset  # exclude the first term since it is the source and has not tags
        for tag in tags
    )))
)

In [None]:
# Feature pipeline used by all scoring models: textual features, tag
# features, and two statistics-based featurizers. Each of the latter two is
# wrapped in a FeatureNormalizer configured with three (name prefix, function)
# pairs: raw values (identity), max-normalized ("NM ") and range-normalized
# ("NR ").
scoring_featurizer = fe.FeatureComposer([
    fe.FeaturizerTextual(
        feature_name_prefix="TEXT "
    ),
    fe.FeaturizerTags(
        # NOTE(review): dict key order is passed as-is; presumably the
        # featurizer does not depend on the positions stored in tag_to_pos -- confirm.
        tag_to_pos.keys(),
        feature_name_prefix="TAG "
    ),
    fe.FeatureNormalizer(
        featurizer=fe.FeaturizerQueryPerformancePredictors(collection_stats, collection_stats_segment_to_segment_id),
        normalization_name_function_list=[
            ("", identity),
            ("NM ", normalize_max),
            ("NR ", normalize_range),
        ],
        feature_name_prefix="QPP "
    ),
    fe.FeatureNormalizer(
        featurizer=fe.FeaturizerSigIR08extended(collection_stats, collection_stats_segment_to_segment_id),
        normalization_name_function_list=[
            ("", identity),
            ("NM ", normalize_max),
            ("NR ", normalize_range),
        ],
        feature_name_prefix="SIGIRV2 "
    ),
])

# TRAINING SUPPORT

## MODEL

In [None]:
class Model(object):
    """Minimal prediction interface; subclasses must override ``predict``."""

    def predict(self, X):
        """Return the predictions for the feature matrix ``X``."""
        raise NotImplementedError()

In [None]:
class XGBModel(Model):
    """``Model`` adapter around a trained ``xgboost.Booster``."""

    def __init__(self, model):
        assert isinstance(model, xgboost.Booster)
        self._model = model

    def _best_ntree_limit(self):
        """Number of trees to use at prediction time.

        Booster attributes are stored as strings, so ``best_iteration`` has
        to be converted before being handed to ``predict``. The iteration
        index is 0-based, hence the ``+ 1`` to include the best iteration
        itself (one tree per boosting round is assumed). Without the
        attribute, 0 is returned, which makes xgboost use every tree.
        """
        best_iteration = self._model.attributes().get("best_iteration")
        if best_iteration is None:
            return 0
        return int(best_iteration) + 1

    def predict(self, X):
        """Return the raw booster predictions for the feature matrix ``X``."""
        return self._model.predict(xgboost.DMatrix(X), ntree_limit=self._best_ntree_limit())

In [None]:
class BinaryModel(Model):
    """Marker base class for models used as binary (keep / prune) predictors."""
    pass

In [None]:
class XGBBinaryClassifier(BinaryModel):
    """Binary decision obtained by thresholding the output of an ``xgboost.Booster``."""

    def __init__(self, model, threshold):
        assert isinstance(model, xgboost.Booster)
        self._model = model
        self._threshold = threshold

    def _best_ntree_limit(self):
        """Number of trees to use at prediction time.

        Booster attributes are stored as strings, so ``best_iteration`` has
        to be converted before being handed to ``predict``. The iteration
        index is 0-based, hence the ``+ 1`` to include the best iteration
        itself (one tree per boosting round is assumed). Without the
        attribute, 0 is returned, which makes xgboost use every tree.
        """
        best_iteration = self._model.attributes().get("best_iteration")
        if best_iteration is None:
            return 0
        return int(best_iteration) + 1

    def predict(self, X):
        """Return a boolean array: True where the booster output reaches the threshold."""
        y = self._model.predict(xgboost.DMatrix(X), ntree_limit=self._best_ntree_limit())

        return y >= self._threshold

## BUILD TRAINING SET

In [None]:
def build_training_set(
    idx_cache,
    qid_list,
    perf_to_metric,
    features_featurizer,
    sequential_greedy_selection=True,
    scoring_model=None
):
    """Build per-query training examples by simulating query expansion.

    For each query the remaining candidate terms are featurized (``X``) and
    labelled (``y``) with the non-negative gain in ``perf_to_metric`` obtained
    by adding that single term to the current query. The selected term is then
    moved from the candidates into the current query and the procedure is
    repeated.

    Args:
        idx_cache: object whose ``cursor()`` context manager provides
            ``get_performance``.
        qid_list: query ids to process.
        perf_to_metric: callable mapping a performance object to a number.
        features_featurizer: object with
            ``transform(curr_repr, exp_repr, num_candidates)``.
        sequential_greedy_selection: when False only the first step is
            computed for each query.
        scoring_model: optional ``Model``; when given its predictions are
            appended to ``X`` as the last column and drive the term selection
            instead of the oracle labels.

    Returns:
        ``(dataset, queries)``: ``dataset`` maps qid to the list of ``(X, y)``
        pairs (one per step); ``queries`` maps qid to the list of query
        representations (terms only, tags dropped) after each accepted term.
        Both lists are empty for queries that are skipped.
    """
    assert isinstance(sequential_greedy_selection, bool)
    assert scoring_model is None or isinstance(scoring_model, Model)

    global qid_to_base_query, qid_to_candidates

    with idx_cache.cursor() as cursor:
        dataset = dict()
        queries = dict()

        for qid in pb.iter_progress(qid_list):
            dataset[qid] = []
            queries[qid] = []

            # queries without any relevant document are left empty
            num_positive_documents = len(qid_to_docid_list[qid])
            if num_positive_documents <= 0:
                continue

            # no improvements can be done, hence this query can be discarded from the training set
            base_num_matches = get_num_match_query(qid, qid_to_base_query[qid])
            if base_num_matches == num_positive_documents:
                continue

            # work on copies: both representations are mutated below
            curr_repr = copy.deepcopy(qid_to_base_query[qid])
            exp_repr = copy.deepcopy(qid_to_candidates[qid])
            num_candidates = qid_to_num_candidates[qid]

            doc_id_list = qid_to_docid_list[qid]
            w2o = qid_to_word_to_occurrence_set[qid]

            # metric of the current (initially: base) query
            # NOTE(review): the meaning of the trailing True argument is not visible here
            base_metric = perf_to_metric(cursor.get_performance(
                curr_repr,
                doc_id_list,
                qid,
                True
            ))

            # NOTE(review): X_list and y_list are never used
            X_list = []
            y_list = []

            while num_candidates > 0:
                X = features_featurizer.transform(curr_repr, exp_repr, num_candidates)
                y = np.zeros(num_candidates)
                if scoring_model is not None:
                    # the score of the scoring model becomes the last feature column
                    y_scoring = scoring_model.predict(X)
                    X = np.column_stack([X, y_scoring])

                for abs_pos, and_pos, syn_pos, term_pos, term_tags in query_terms_iterator(exp_repr):
                    # temporarily add the candidate to its synset (undone by the pop below)
                    curr_repr[and_pos][syn_pos].append(term_tags)
                    if len(w2o[term_tags[0]] - w2o[curr_repr[and_pos][syn_pos][0][0]]) == 0:
                        # the candidate occurs nowhere the first term of the synset does not
                        y[abs_pos] = 0
                    elif not(base_num_matches < get_num_match_query(qid, curr_repr) >= 2):
                        # the number of matches must strictly grow and be at least 2
                        y[abs_pos] = 0
                    else:
                        # label = metric gain over the current query, floored at 0
                        y[abs_pos] = max(
                            0.0,
                            perf_to_metric(cursor.get_performance(
                                curr_repr,
                                doc_id_list,
                                qid,
                                True
                            )) - base_metric
                        )
                    curr_repr[and_pos][syn_pos].pop()

                dataset[qid].append((X,y))

                # get the best term according to the oracle or according to the scoring model
                if scoring_model is None:
                    best_abs_pos = np.argmax(y)
                else:
                    best_abs_pos = np.argmax(y_scoring)

                # get the score and the position of the best term
                best_score = y[best_abs_pos]
                best_tpl = None
                for tpl in query_terms_iterator(exp_repr, only_positions=[best_abs_pos]):
                    best_tpl = tpl
                assert best_tpl is not None

                # when to stop the sequential_greedy_selection
                # (the step recorded just above is kept even though no term is added)
                if best_score <= 0:
                    break

                # update the current representation
                abs_pos, and_pos, syn_pos, term_pos, term_tags = best_tpl
                curr_repr[and_pos][syn_pos].append(term_tags)
                exp_repr[and_pos][syn_pos].pop(term_pos)
                num_candidates -= 1

                # new baseline: matches are recomputed, the metric is updated
                # by adding the gain of the selected term
                base_num_matches = get_num_match_query(qid, curr_repr)
                base_metric = y[abs_pos] + base_metric

                # snapshot of the expanded query, keeping only the term of each entry
                queries[qid].append([[[(term_tags[0],) for term_tags in synset] for synset in and_query] for and_query in curr_repr])

                if not sequential_greedy_selection:
                    break

    return dataset, queries

In [None]:
def dump_training_set(filename, dataset, queries):
    """Serialize a training set to ``filename``.

    Layout: the pickled ``queries`` object, then for every qid of ``dataset``
    a pickled ``(qid, number_of_steps)`` header followed by the ``X`` and
    ``y`` arrays of each step written with ``np.save``.

    The file is opened in binary mode because both the highest pickle
    protocol and the ``.npy`` format are binary; text mode can corrupt them
    on platforms that translate newlines.
    """
    with open(filename, "wb") as outfile:
        cPickle.dump(queries, outfile, protocol=cPickle.HIGHEST_PROTOCOL)
        for qid, Xy_list in pb.iteritems_progress(dataset):
            cPickle.dump((qid, len(Xy_list)), outfile, protocol=cPickle.HIGHEST_PROTOCOL)
            for X, y in Xy_list:
                np.save(outfile, X)
                np.save(outfile, y)

In [None]:
def read_dataset_set(filename):
    """Read back a file written by ``dump_training_set``.

    Returns ``(dataset, oracle)`` where ``oracle`` is the first pickled object
    of the file and ``dataset`` maps each qid to its list of ``(X, y)`` array
    pairs. One per-query record is expected for every entry of ``oracle``.

    The file is opened in binary mode because it contains binary pickle
    records and ``.npy`` payloads.
    """
    with open(filename, "rb") as infile:
        oracle = cPickle.load(infile)
        dataset = dict()
        for i in pb.iter_progress(xrange(len(oracle)), size=len(oracle)):
            qid, num_Xy = cPickle.load(infile)
            dataset[qid] = []
            for j in xrange(num_Xy):
                # read order matters: X was saved before y
                X = np.load(infile)
                y = np.load(infile)
                dataset[qid].append((X, y))
    return dataset, oracle

## TRAINING

In [None]:
seed = 0

# sorted() fixes the starting order so that the seeded shuffle is reproducible
full_qid_list = sorted(queries_with_recall_improvement)

random.seed(seed)
random.shuffle(full_qid_list)

# 70% train / 15% validation / 15% test.
# Explicit floor division keeps the cut points integers (valid slice bounds)
# whatever the division semantics in effect.
c1 = len(full_qid_list) * 70 // 100
c2 = len(full_qid_list) * 85 // 100

train_qid_list = full_qid_list[:c1]
valid_qid_list = full_qid_list[c1:c2]
test_qid_list = full_qid_list[c2:]

In [None]:
def dataset_iterator(raw_dataset, qid_list, remove_last_step=True, only_first_step=False):
    """Yield ``(qid, step, X, y)`` for the recorded steps of the given queries.

    ``raw_dataset[qid]`` is the per-step sequence of a query; items 0 and 1 of
    each step are yielded as ``X`` and ``y``. With ``remove_last_step`` the
    final step of every query is dropped; with ``only_first_step`` at most the
    first (remaining) step is produced.
    """
    assert all(qid in raw_dataset for qid in qid_list)
    for qid in qid_list:
        steps = raw_dataset[qid]
        num_steps = len(steps) - 1 if remove_last_step else len(steps)
        if only_first_step:
            num_steps = min(1, num_steps)

        step = 0
        while step < num_steps:
            entry = steps[step]
            yield qid, step, entry[0], entry[1]  # which is qid, step, X, y
            step += 1

In [None]:
def get_top_k_arg(k, vec):
    """Return the indices of the ``k`` largest entries of ``vec``, largest first.

    A full ``argsort`` is used (found faster than a ``heapq.nlargest`` based
    selection by the original author). When ``vec`` has fewer than ``k``
    entries all indices are returned.
    """
    ascending = vec.argsort()
    top_ascending = ascending[-k:]
    return top_ascending[::-1]

In [None]:
def get_feval_gain_at_k(k, xgmatrix_to_groups):
    """Build an xgboost ``feval`` computing the mean relative gain at ``k``.

    For each group the sum of the labels of the ``k`` best-predicted rows is
    divided by the sum of the ``k`` largest labels (or by 1 when that sum is
    zero); the metric is the average of this ratio over the groups.
    ``xgmatrix_to_groups`` maps each ``DMatrix`` to the array of its group
    sizes, which the returned function looks up for the matrix it receives.
    """
    assert isinstance(k, int) and k > 0
    assert isinstance(xgmatrix_to_groups, dict)
    assert all(isinstance(key, xgboost.DMatrix) and isinstance(value, np.ndarray) for key, value in xgmatrix_to_groups.iteritems())

    metric_name = "gain@{}".format(k)

    def feval_gain_at_k(preds, dtrain):
        groups = xgmatrix_to_groups.get(dtrain, None)
        assert groups is not None
        labels = dtrain.get_label()

        ratio_sum = 0.0
        start = 0
        for group_size in groups:
            stop = start + group_size
            group_labels = labels[start:stop]
            group_preds = preds[start:stop]

            ideal_gain = group_labels[get_top_k_arg(k, group_labels)].sum()
            achieved_gain = group_labels[get_top_k_arg(k, group_preds)].sum()
            ratio_sum += 1.0 * achieved_gain / (ideal_gain or 1.0)

            start = stop

        return (metric_name, float(ratio_sum) / len(groups))

    return feval_gain_at_k

## SCORING

In [None]:
def get_scoring_training(raw_dataset, qid_list, sequential_greedy_selection, transform_fun=None):
    """Assemble the grouped ``xgboost.DMatrix`` used to train a scoring model.

    Every step produced by ``dataset_iterator`` (last step of each query
    removed; only the first step when ``sequential_greedy_selection`` is
    False) becomes one group. Without ``transform_fun`` the labels are clipped
    to [0, 1] and all weights are 1; otherwise ``transform_fun(X, y)`` must
    return ``(X, y, w)``.

    Returns ``(xgmatrix, group_sizes)``.
    """
    assert hasattr(qid_list, "__iter__") and all(qid in raw_dataset for qid in qid_list)
    assert isinstance(sequential_greedy_selection, bool)
    assert transform_fun is None or hasattr(transform_fun, "__call__")

    steps = dataset_iterator(
        raw_dataset=raw_dataset,
        qid_list=qid_list,
        remove_last_step=True,
        only_first_step=(not sequential_greedy_selection)
    )

    feature_blocks = []
    label_blocks = []
    weight_blocks = []
    group_sizes = []
    for qid, step, X, y in pb.iter_progress(steps):
        if transform_fun:
            X, y, w = transform_fun(X, y)
        else:
            y = np.clip(y, 0, 1)
            w = np.ones(y.shape)

        feature_blocks.append(X)
        label_blocks.append(y)
        weight_blocks.append(w)
        group_sizes.append(X.shape[0])

    all_features = np.concatenate(feature_blocks, axis=0)
    all_labels = np.concatenate(label_blocks)
    all_weights = np.concatenate(weight_blocks)

    xgmatrix = xgboost.DMatrix(
        data=all_features,
        label=all_labels,
        weight=all_weights,
    )
    xgmatrix.set_group(group_sizes)
    return xgmatrix, group_sizes

## BASELINE - StaticRecall

In [None]:
# Baseline dataset: labels are recall gains and only the first selection step
# is computed for each query (no scoring model, so the oracle drives selection).
recall_dataset, recall_oracle = build_training_set(
    idx_cache,
    queries_with_recall_improvement,
    perf_to_recall,
    scoring_featurizer,
    sequential_greedy_selection=False
)

In [None]:
# Persist the baseline dataset as "recall_dataset.pickle" in the tmp directory.
dump_training_set(cfg.tmp_dir + "{}_dataset.pickle".format("recall"), recall_dataset, recall_oracle)

In [None]:
# Build the grouped train/validation matrices of the recall baseline from the
# first step of every query.
partition_to_xgmatrix = dict()
xgmatrix_to_groups = dict()

for what, qid_list in [("train", train_qid_list), ("valid", valid_qid_list)]:
    new_X = []
    new_y = []
    new_groups = []
    for qid in pb.iter_progress(qid_list, labeling_fun={"prefix": what}):
        steps = recall_dataset[qid]
        # queries discarded while building the dataset have no recorded step
        if not steps:
            continue
        y = steps[0][1]
        # `y.max` must be CALLED: comparing the bound method itself with 0
        # never triggers the skip, so groups without any positive label
        # (useless for ranking) used to be kept.
        if y.max() <= 0:
            continue
        new_X.append(steps[0][0])
        # labels are clipped to [0, 1] and rescaled to [0, 10]
        new_y.append(10.0 * np.clip(y, 0, 1))
        new_groups.append(y.size)

    xgmatrix = xgboost.DMatrix(
        np.concatenate(new_X, axis=0),
        label=np.concatenate(new_y, axis=0)
    )
    xgmatrix.set_group(new_groups)
    partition_to_xgmatrix[what] = xgmatrix
    xgmatrix_to_groups[xgmatrix] = np.array(new_groups)

In [None]:
%%time
# Train the recall baseline as a pairwise ranker. Both the built-in ndcg@1
# and the custom gain@5 metric are evaluated on the training and validation
# matrices; early stopping is configured with a 20-round patience and
# maximize=True.
baseline_scoring_model = xgboost.train(
    params={
        'objective': 'rank:pairwise',
        'eval_metric': 'ndcg@1',
        'max_depth': 6,
        'eta': 0.1,
        'silent': 0,
    },
    num_boost_round=300,
    dtrain=partition_to_xgmatrix["train"],
    evals=[(partition_to_xgmatrix["train"], "training"), (partition_to_xgmatrix["valid"], 'validation')],
    early_stopping_rounds=20,
    feval=get_feval_gain_at_k(5, xgmatrix_to_groups),
    maximize=True
)
# Recorded output of a previous run:
# [210]	training-ndcg@1:0.741305	validation-ndcg@1:0.737976	training-gain@5:0.652657	validation-gain@5:0.620137

In [None]:
# Persist the baseline model, then drop the large intermediate objects.
baseline_scoring_model.save_model(cfg.tmp_dir + "scoring_recall_static.model")

In [None]:
# the matrices of the baseline are not needed any more
del partition_to_xgmatrix, xgmatrix_to_groups

In [None]:
# neither is the raw baseline dataset (it was dumped to disk above)
del recall_dataset, recall_oracle

## STATIC VS GREEDY EET200

In [None]:
# Dataset for the static-vs-greedy scoring models, built with the full
# sequential greedy selection.
# NOTE(review): despite the "eet200" name the metric actually passed is
# perf_to_recall; the EET metric is commented out below -- confirm which one
# is intended.
eet200_dataset, eet200_oracle = build_training_set(
    idx_cache,
    queries_with_recall_improvement,
    #get_perf_to_my_eet_function(200.0, 1.0),
    perf_to_recall,
    scoring_featurizer,
    sequential_greedy_selection=True
)

In [None]:
# Persist the dataset as "eet200_dataset.pickle" in the tmp directory.
dump_training_set(cfg.tmp_dir + "{}_dataset.pickle".format("eet200"), eet200_dataset, eet200_oracle)

In [None]:
# Build the four scoring matrices ("train greedy", "train static",
# "valid greedy", "valid static") from the same raw dataset; the static
# variants only use the first step of every query.
partition_to_xgmatrix = dict()
xgmatrix_to_groups = dict()
for what, qid_list in [("train", train_qid_list), ("valid", valid_qid_list)]:
    for greedy in [True, False]:
        name = "{} {}".format(what, "greedy" if greedy else "static")
        xgmatrix, groups = get_scoring_training(eet200_dataset, qid_list, sequential_greedy_selection=greedy)
        partition_to_xgmatrix[name] = xgmatrix
        # group sizes are looked up by matrix inside the custom feval
        xgmatrix_to_groups[xgmatrix] = np.array(groups)
        del xgmatrix, groups

### STATIC

In [None]:
%%time
# Static scoring model: pairwise ranker trained on the first-step matrices,
# evaluated with ndcg@5 and the custom gain@5 metric.
scoring_model_static = xgboost.train(
    params={
        'objective': 'rank:pairwise',
        'eval_metric': 'ndcg@5',
        'max_depth': 6,
        'eta': 0.1,
        'silent': 0
    },
    num_boost_round=300,
    dtrain=partition_to_xgmatrix["train static"],
    evals=[(partition_to_xgmatrix["train static"], "training"), (partition_to_xgmatrix["valid static"], 'validation')],
    early_stopping_rounds=20,
    feval=get_feval_gain_at_k(5, xgmatrix_to_groups),
    maximize=True
)
# Recorded output of a previous run:
# [252]	training-ndcg@5:0.999804	validation-ndcg@5:0.999803	training-gain@5:0.673998	validation-gain@5:0.637472

In [None]:
# Persist the static scoring model.
%time scoring_model_static.save_model(cfg.tmp_dir + "scoring_eet200_static.model")

### GREEDY

In [None]:
%%time
# Greedy scoring model: pairwise ranker trained on all the selection steps,
# evaluated with ndcg@1 and the custom gain@1 metric.
scoring_model_greedy = xgboost.train(
    params={
        'objective': 'rank:pairwise',
        'eval_metric': 'ndcg@1',
        'max_depth': 6,
        'eta': 0.1,
        'silent': 0,
    },
    num_boost_round=300,
    dtrain=partition_to_xgmatrix["train greedy"],
    evals=[(partition_to_xgmatrix["train greedy"], "training"), (partition_to_xgmatrix["valid greedy"], 'validation')],
    early_stopping_rounds=20,
    feval=get_feval_gain_at_k(1, xgmatrix_to_groups),
    maximize=True
)
# Recorded outputs of previous runs:
# 0.469493 in 182
# [171]	training-ndcg@1:0.999741	validation-ndcg@1:0.999695	training-gain@1:0.5059	validation-gain@1:0.481192

In [None]:
# Persist the greedy scoring model.
%time scoring_model_greedy.save_model(cfg.tmp_dir + "scoring_eet200_greedy.model")

In [None]:
# the scoring matrices are not needed any more
del partition_to_xgmatrix, xgmatrix_to_groups

# PRUNING

In [None]:
# Reload the static scoring model from disk and show its stored attributes.
scoring_model_static = xgboost.Booster()
scoring_model_static.load_model(cfg.tmp_dir + "scoring_eet200_static.model")
print scoring_model_static.attributes()

In [None]:
# Same for the greedy scoring model.
scoring_model_greedy = xgboost.Booster()
scoring_model_greedy.load_model(cfg.tmp_dir + "scoring_eet200_greedy.model")
print scoring_model_greedy.attributes()

In [None]:
# Pruning dataset, static flavour: terms are selected by the static scoring
# model (first step only) and labelled with the EET metric gain
# (step 200.0, step ratio 1.0). Built over the full query list.
raw_dataset_static, raw_queries_static = build_training_set(
    idx_cache,
    full_qid_list,
    get_perf_to_my_eet_function(200.0, 1.0),
    scoring_featurizer,
    sequential_greedy_selection=False,
    scoring_model=XGBModel(scoring_model_static)
)

In [None]:
dump_training_set(cfg.tmp_dir + "{}_pruning_static_dataset.pickle".format("eet200"), raw_dataset_static, raw_queries_static)

In [None]:
# Pruning dataset, greedy flavour: same metric, terms selected step by step
# by the greedy scoring model.
raw_dataset_greedy, raw_queries_greedy = build_training_set(
    idx_cache,
    full_qid_list,
    get_perf_to_my_eet_function(200.0, 1.0),
    scoring_featurizer,
    sequential_greedy_selection=True,
    scoring_model=XGBModel(scoring_model_greedy)
)

In [None]:
dump_training_set(cfg.tmp_dir + "{}_pruning_greedy_dataset.pickle".format("eet200"), raw_dataset_greedy, raw_queries_greedy)

In [None]:
%%time
# Save the index cache to the dump file that is looked up when the notebook starts.
print len(idx_cache)
idx_cache.dump(idx_cache_file_path)

In [None]:
%%cython

import numpy as np
cimport numpy as np
from numpy cimport ndarray

def find_threshold(
    ndarray[np.float32_t, ndim=1] y_true,
    ndarray[np.float32_t, ndim=1] y_pred
):
    """Return the threshold ``t`` maximizing the accuracy of ``y_pred >= t``.

    A sample is positive when ``y_true > 0`` and is predicted positive when
    ``y_pred >= t``. Samples are visited by increasing prediction; after each
    one, the candidate threshold is the NEXT distinct prediction (or +inf
    after the last sample), so that every visited sample is really rejected
    by ``>=``. Cuts between tied predictions are skipped because no threshold
    can separate them. When rejecting nothing is best, the smallest
    prediction is returned so that every sample stays positive.

    Raises IndexError on empty input.
    """
    cdef Py_ssize_t num_samples = y_pred.shape[0]
    cdef Py_ssize_t rank
    cdef Py_ssize_t p
    cdef Py_ssize_t score = (y_true > 0).sum()
    cdef Py_ssize_t best_score = score
    cdef double candidate
    cdef double best_threshold
    cdef ndarray[np.intp_t, ndim=1] order = np.argsort(y_pred)

    # Starting point: threshold at the minimum, i.e. everything predicted
    # positive; `score` is the number of correctly classified samples.
    best_threshold = y_pred[order[0]]

    for rank in range(num_samples):
        p = order[rank]
        # sample p moves to the predicted-negative side
        if y_true[p] > 0.0:
            score -= 1
        else:
            score += 1

        if rank + 1 < num_samples:
            candidate = y_pred[order[rank + 1]]
            if candidate == y_pred[p]:
                # tie with the next sample: not a realizable cut
                continue
        else:
            candidate = float("inf")

        if score > best_score:
            best_score = score
            best_threshold = candidate

    # sanity check: once everything is rejected, only the negatives are correct
    assert score == (y_true <= 0).sum()

    return best_threshold

In [None]:
def get_pruning_training(raw_dataset, qid_list, greedy):
    """Assemble the grouped ``xgboost.DMatrix`` used to train a pruning model.

    For every step produced by ``dataset_iterator`` (all steps when
    ``greedy``, only the first one otherwise) the rows with the highest value
    in the last feature column are kept: the single best row when ``greedy``,
    the best five otherwise. Kept rows form one group and all get weight 1.

    Returns ``(xgmatrix, group_sizes)``.
    """
    assert hasattr(qid_list, "__iter__") and all(qid in raw_dataset for qid in qid_list)
    assert isinstance(greedy, bool)

    rows_per_step = 1 if greedy else 5
    steps = dataset_iterator(
        raw_dataset=raw_dataset,
        qid_list=qid_list,
        remove_last_step=False,
        only_first_step=(not greedy)
    )

    feature_blocks = []
    label_blocks = []
    weight_blocks = []
    group_sizes = []
    for qid, step, X, y in pb.iter_progress(steps):
        # rank the rows of this step by the last feature column
        selected = get_top_k_arg(rows_per_step, X[:,-1])

        kept_rows = X[selected]
        feature_blocks.append(kept_rows)
        label_blocks.append(y[selected])
        weight_blocks.append(np.full(len(selected), 1))
        group_sizes.append(kept_rows.shape[0])

    xgmatrix = xgboost.DMatrix(
        data=np.concatenate(feature_blocks, axis=0),
        label=np.concatenate(label_blocks),
        weight=np.concatenate(weight_blocks),
    )
    xgmatrix.set_group(group_sizes)
    return xgmatrix, group_sizes

In [None]:
# Build the four pruning matrices; the greedy variants read the dataset
# produced with the greedy scoring model, the static ones the dataset produced
# with the static scoring model.
partition_to_xgmatrix = dict()
xgmatrix_to_groups = dict()
for what, qid_list in [("train", train_qid_list), ("valid", valid_qid_list)]:
    for greedy in [True, False]:
        name = "{} {}".format(what, "greedy" if greedy else "static")
        xgmatrix, groups = get_pruning_training(raw_dataset_greedy if greedy else raw_dataset_static, qid_list, greedy=greedy)
        partition_to_xgmatrix[name] = xgmatrix
        xgmatrix_to_groups[xgmatrix] = np.array(groups)
        del xgmatrix, groups

### STATIC

In [None]:
%%time
# Static pruning model: regression on the metric gain of the selected terms,
# monitored with RMSE (lower is better, hence maximize=False).
pruning_model_static = xgboost.train(
    params={
        'objective': 'reg:linear',
        'eval_metric': 'rmse',
        'max_depth': 6,
        'eta': 0.1,
#        'base_score': 0.5,
        'scale_pos_weight': 0.8,  # should be something like sum(negative cases) / sum(positive cases)
        'silent': 0
    },
    num_boost_round=200,
    dtrain=partition_to_xgmatrix["train static"],
    evals=[(partition_to_xgmatrix["train static"], "training"), (partition_to_xgmatrix["valid static"], 'validation')],
    early_stopping_rounds=20,
    maximize=False
)
# Recorded output of a previous run:
# [145]	training-rmse:0.063551	validation-rmse:0.088251

In [None]:
# Persist the static pruning model.
%time pruning_model_static.save_model(cfg.tmp_dir + "pruning_{}_static.model".format("eet_200"))

In [None]:
y_true = partition_to_xgmatrix["train static"].get_label()
y_pred = pruning_model_static.predict(
    partition_to_xgmatrix["train static"],
    ntree_limit=pruning_model_static.attr("best_iteration")
)

pruning_threshold_static = find_threshold(y_true, y_pred)

1.0 * ((y_pred >= pruning_threshold_static) == (y_true > 0)).sum() / y_true.size

In [None]:
print pruning_threshold_static

### GREEDY

In [None]:
%%time
# Greedy pruning model: same regression setup as the static one, trained on
# the greedy matrices with a slightly different eta and scale_pos_weight.
pruning_model_greedy = xgboost.train(
    params={
        'objective': 'reg:linear',
        'eval_metric': 'rmse',
        'max_depth': 6,
        'eta': 0.09,
#        'base_score': 0.5,
        'scale_pos_weight': 0.7,  # should be something like sum(negative cases) / sum(positive cases)
        'silent': 0
    },
    num_boost_round=200,
    dtrain=partition_to_xgmatrix["train greedy"],
    evals=[(partition_to_xgmatrix["train greedy"], "training"), (partition_to_xgmatrix["valid greedy"], 'validation')],
    early_stopping_rounds=20,
    maximize=False
)
# Recorded outputs of previous runs:
# 0.086712 with 102
# [112]	training-rmse:0.061863	validation-rmse:0.088963

In [None]:
# Persist the greedy pruning model.
%time pruning_model_greedy.save_model(cfg.tmp_dir + "pruning_{}_greedy.model".format("eet_200"))

In [None]:
y_true = partition_to_xgmatrix["train greedy"].get_label()
y_pred = pruning_model_greedy.predict(
    partition_to_xgmatrix["train greedy"],
    ntree_limit=pruning_model_greedy.attr("best_iteration")
)

pruning_threshold_greedy = find_threshold(y_true, y_pred)

1.0 * ((y_pred >= pruning_threshold_greedy) == (y_true > 0)).sum() / y_true.size

In [None]:
print pruning_threshold_greedy

In [None]:
# free the pruning matrices ...
del partition_to_xgmatrix, xgmatrix_to_groups

In [None]:
# ... and the raw pruning datasets (already dumped to disk)
del raw_dataset_static, raw_dataset_greedy

# FEATURE SCORE

In [None]:
name_model_list = [
    ("S2_Recall", baseline_scoring_model),
    ("S2_EET", scoring_model_static),
    ("SGS_EET", scoring_model_greedy),
    ("SGS+Pruning_EET", pruning_model_static),
    ("SGS+Pruning_EET", pruning_model_greedy),
]

features = list(scoring_featurizer.feature_names()) + ["scoring"]
features_scores = []
for name, model in name_model_list:
    print name
    scores = np.zeros(len(features))
    for feature_name, score in model.get_fscore().iteritems():
        scores[int(feature_name[1:])] = score
    features_scores.append(scores)
    del scores

df = pd.DataFrame(data=np.array(features_scores).T, index=features, columns=[name for name, _ in name_model_list])
df.sort_values(by="S2_Recall", inplace=True, ascending=False)
df

# APPLY MODEL

In [None]:
def apply_model_static_greedy(
    scoring_featurizer,
    scoring_model,
    pruning_model,
    qid,
    num_terms=1,
    greedy=True,
    compute_unpruned=True,
):
    """Expand the base query of `qid` with up to `num_terms` candidate terms.

    The candidates come from the notebook globals `qid_to_candidates` and
    `qid_to_num_candidates`, the starting query from `qid_to_base_query`.

    Static mode (`greedy=False`): the candidates are scored once against the
    base query and the `num_terms` best-scoring positions are considered, best
    first. One representation is stored per considered position, each one
    extending the previous one. In the pruned list a term rejected by
    `pruning_model` is skipped (the entry then equals the previous one) but
    later terms can still be added. Without a pruning model the pruned list
    stays empty.

    Greedy mode (`greedy=True`): at every step the remaining candidates are
    re-scored against the current expanded query, the best one is moved from
    the candidates into the query and a snapshot is stored. As soon as
    `pruning_model` rejects a selected term, no further snapshot is added to
    the pruned list. Without a pruning model nothing is ever rejected, so the
    pruned list receives every step.

    Parameters:
        scoring_featurizer: object whose `transform(query_repr, exp_repr,
            num_candidates)` returns the feature matrix of the candidates.
        scoring_model: `Model` instance; `predict(X)` gives one score per
            candidate (used as a numpy array - TODO confirm).
        pruning_model: `BinaryModel` instance or None; it is fed the scoring
            features of the selected candidates plus their score as last column.
        qid: key into the per-query global dictionaries.
        num_terms: maximum number of terms to add (must be > 0).
        greedy: selects the greedy (True) or static (False) strategy.
        compute_unpruned: when False the unpruned list is left empty, and the
            greedy loop stops as soon as a term is pruned.

    Returns:
        The tuple `(result_list_pruned, result_list_un_pruned)`.
    """
    assert isinstance(scoring_model, Model)
    assert (pruning_model is None) or isinstance(pruning_model, BinaryModel)
    assert int(num_terms) > 0

    num_terms = int(num_terms)

    result_list_pruned = []
    result_list_un_pruned = []

    exp_repr = qid_to_candidates[qid]
    num_candidates = qid_to_num_candidates[qid]

    if not greedy:
        base_repr = qid_to_base_query[qid]

        # save the STATIC repr: append to target_list one representation per
        # entry of `positions`, adding the i-th term only when y_pruning[i] is truthy
        def _save(target_list, positions, y_pruning):
            static_repr = base_repr
            for i, abs_pos in enumerate(positions):
                if y_pruning[i]:
                    # clone the query representation, so that the entries
                    # already stored (and base_repr) are never mutated
                    static_repr = copy.deepcopy(static_repr)
                    # find the term in the query (the last yielded tuple is kept)
                    best_tpl = None
                    for tpl in query_terms_iterator(exp_repr, only_positions=[abs_pos]):
                        best_tpl = tpl
                    assert best_tpl is not None and abs_pos == best_tpl[0] 
                    # add the term inside the query representation
                    abs_pos, and_pos, syn_pos, term_pos, term_tags = best_tpl
                    static_repr[and_pos][syn_pos].append(term_tags)
                target_list.append(static_repr)

        X = scoring_featurizer.transform(base_repr, exp_repr, num_candidates)
        y_scor = scoring_model.predict(X)

        # positions of the num_terms highest scores, best first
        best_positions = y_scor.argsort()[-num_terms:][::-1]
        if compute_unpruned:
            # an all-ones mask keeps every selected term
            _save(result_list_un_pruned, best_positions, np.ones(best_positions.size))

        if pruning_model:
            # pruning features = scoring features of the selected terms + their score
            X = np.column_stack([X[best_positions], y_scor[best_positions]])
            y_post = pruning_model.predict(X)

            _save(result_list_pruned, best_positions, y_post)
    else:
        # work on private copies: both structures are mutated at every step
        greedy_repr = copy.deepcopy(qid_to_base_query[qid])
        exp_repr = copy.deepcopy(exp_repr)
        step = 0
        # becomes True at the first rejected term and is never reset
        pruned = False

        while num_candidates > 0:
            X = scoring_featurizer.transform(greedy_repr, exp_repr, num_candidates)
            y_scor = scoring_model.predict(X)
            step += 1

            # save the GREEDY representation: locate the best-scoring candidate
            y_argmax = y_scor.argmax()
            best_tpl = None
            for best_tpl in query_terms_iterator(exp_repr, only_positions=[y_argmax]):
                pass
            assert best_tpl is not None
            best_positions = [best_tpl[0]]

            # pruning
            if pruning_model:
                X = np.column_stack([X[best_positions], y_scor[best_positions]])
                y = pruning_model.predict(X)
                if not y[0]:
                    pruned = True

            # update the status: move the chosen term from the candidates to the query
            # (this happens even when the term has just been pruned)
            abs_pos, and_pos, syn_pos, term_pos, term_tags = best_tpl
            greedy_repr[and_pos][syn_pos].append(term_tags)
            exp_repr[and_pos][syn_pos].pop(term_pos)

            if not pruned:
                result_list_pruned.append(copy.deepcopy(greedy_repr))
            if compute_unpruned:
                result_list_un_pruned.append(copy.deepcopy(greedy_repr))

            # nothing else can be stored once pruned, unless the unpruned list is requested
            if pruned and not compute_unpruned:
                break

            if step >= num_terms:
                break
            num_candidates -= 1

    return (result_list_pruned, result_list_un_pruned)

In [None]:
# Numbers of expansion terms to evaluate, and the largest of them.
terms_range = [1,3,5]
max_terms_range = max(terms_range)

In [None]:
# Query representations to evaluate, keyed by configuration label (insertion
# order is preserved). "NoEXP" maps every test qid to its un-expanded base query.
representations = collections.OrderedDict()
representations["NoEXP"] = {qid: qid_to_base_query[qid] for qid in test_qid_list}

In [None]:
# Create one empty qid -> representation mapping per model configuration and
# per expansion size; they are filled in by the cells that apply the models.
for num_terms in terms_range:
    for _label_template in (
        "S2 Recall [{}]",
        "S2 EET [{}]",
        "S2 + Pruning EET [{}]",
        "SGS EET [{}]",
        "SGS + Pruning EET [{}]",
    ):
        representations[_label_template.format(num_terms)] = dict()

In [None]:
def _get_repr(qid, repr_list, num_terms):
    _pos = num_terms - 1
    if num_terms == 0:
        return qid_to_base_query[qid]
    _len = len(repr_list)
    if _pos < _len:
        return repr_list[_pos]
    if _len > 0:
        return repr_list[-1]
    return qid_to_base_query[qid]

In [None]:
# Oracle expansions, truncated to each of the evaluated sizes.
for num_terms in terms_range:
    representations["Oracle SGS + Pruning EET [{}]".format(num_terms)] = {
        qid: _get_repr(qid, eet200_oracle[qid], num_terms)
        for qid in test_qid_list
    }

In [None]:
# Static expansion driven by the baseline scoring model, without pruning.
# The model is applied once per query for the largest size; the smaller sizes
# are then read out of the returned list.
for qid in pb.iter_progress(test_qid_list):
    pruned_reprs, unpruned_reprs = apply_model_static_greedy(
        scoring_featurizer,
        scoring_model=XGBModel(baseline_scoring_model),
        pruning_model=None,
        qid=qid,
        num_terms=max_terms_range,
        greedy=False
    )
    # only the unpruned list is used here, since no pruning model is given
    for num_terms in terms_range:
        representations["S2 Recall [{}]".format(num_terms)][qid] = _get_repr(qid, unpruned_reprs, num_terms)

In [None]:
# Static expansion with the static scoring model; the same call yields both the
# unpruned ("S2 EET") and the pruned ("S2 + Pruning EET") representations.
for qid in pb.iter_progress(test_qid_list):
    pruned_reprs, unpruned_reprs = apply_model_static_greedy(
        scoring_featurizer,
        scoring_model=XGBModel(scoring_model_static),
        pruning_model=XGBBinaryClassifier(pruning_model_static, pruning_threshold_static),
        qid=qid,
        num_terms=max_terms_range,
        greedy=False
    )
    # computed once for the largest size, then read out for every evaluated size
    for num_terms in terms_range:
        representations["S2 EET [{}]".format(num_terms)][qid] = _get_repr(qid, unpruned_reprs, num_terms)
        representations["S2 + Pruning EET [{}]".format(num_terms)][qid] = _get_repr(qid, pruned_reprs, num_terms)

In [None]:
# Greedy expansion with the greedy scoring model; the same call yields both the
# unpruned ("SGS EET") and the pruned ("SGS + Pruning EET") representations.
for qid in pb.iter_progress(test_qid_list):
    pruned_reprs, unpruned_reprs = apply_model_static_greedy(
        scoring_featurizer,
        scoring_model=XGBModel(scoring_model_greedy),
        pruning_model=XGBBinaryClassifier(pruning_model_greedy, pruning_threshold_greedy),
        qid=qid,
        num_terms=max_terms_range,
        greedy=True
    )
    # computed once for the largest size, then read out for every evaluated size
    for num_terms in terms_range:
        representations["SGS EET [{}]".format(num_terms)][qid] = _get_repr(qid, unpruned_reprs, num_terms)
        representations["SGS + Pruning EET [{}]".format(num_terms)][qid] = _get_repr(qid, pruned_reprs, num_terms)

In [None]:
# Oracle expansion requested with 100 terms: _get_repr falls back to the last
# available representation when the oracle list is shorter than that.
# NOTE(review): the key keeps the literal "{}" placeholder (there is no
# .format call) - presumably on purpose, to label the unbounded oracle; confirm.
representations["Oracle SGS + Pruning EET [{}]"] = {
    qid: _get_repr(qid, eet200_oracle[qid], 100)
    for qid in test_qid_list
}

In [None]:
# Evaluation results, keyed by configuration label (insertion order is kept).
performances = collections.OrderedDict()

In [None]:
# Evaluate every stored representation through a cursor of the index cache.
with idx_cache.cursor() as cursor:
    for what, qid2repr in representations.iteritems():
        # configurations that were never filled in have nothing to evaluate
        if not qid2repr:
            continue
        performances.setdefault(what, dict())

        # Every qid is evaluated again, even when `performances` already holds
        # an entry for it (skipping those entries is currently disabled).
        for qid, query_repr in pb.iteritems_progress(qid2repr, labeling_fun={"prefix": what}, every=10):
            performances[what][qid] = cursor.get_performance(query_repr, qid_to_docid_list[qid], qid)

In [None]:
%%time
# Report the number of cached entries, then write the cache back to the file
# it is loaded from at the top of the notebook.
print len(idx_cache)
idx_cache.dump(idx_cache_file_path)

### SHOW RESULTS TEST

In [None]:
# Expansion sizes shown in the results table below, and the largest of them.
terms_range = [1,3,5]#range(1, 3+1)
max_terms_range = max(terms_range)

In [None]:
def query_to_num_exp(query):
    """Count the terms of `query` that are not the first one of their synset."""
    # tpl[3] is the term_pos inside the synset
    return sum(1 for tpl in query_terms_iterator(query) if tpl[3] != 0)


f = get_perf_to_my_eet_function(200.0, 1.0)

# One entry per table column: (header, cell format, per-query value). Every
# cell shows "average ± standard deviation" of the value over the test queries.
column_specs = [
    ("Exe Time", "{: >6.1f} ± {: >6.1f}",
     lambda name, qid: performances[name][qid].exe_time),
    ("Recall", "{: >4.2f}% ± {: >4.1f}",
     lambda name, qid: 100.0*perf_to_recall(performances[name][qid])),
    ("EET(200,1.0)", "{: >4.2f}% ± {: >4.1f}",
     lambda name, qid: 100.0*f(performances[name][qid])),
    ("#terms", "{:.1f} ± {:.1f}",
     lambda name, qid: query_to_num_exp(representations[name][qid])),
]

columns = []
table = collections.OrderedDict((name, []) for name in performances)
for column, cell_format, value_of in column_specs:
    columns.append(column)
    for name in performances:
        tmp = [value_of(name, qid) for qid in test_qid_list]
        table[name].append(cell_format.format(np.average(tmp), np.std(tmp)))

# One row per configuration; only the rows listed below are displayed, grouped
# by expansion size.
df = pd.DataFrame(table.values(), table.keys(), columns)
df.loc[
    ["NoEXP"] + [
        template.format(num_terms)
        for num_terms in terms_range
        for template in [
            "S2 Recall [{}]",
            "S2 EET [{}]",
            "SGS EET [{}]",
            "S2 + Pruning EET [{}]",
            "SGS + Pruning EET [{}]",
            "Oracle SGS + Pruning EET [{}]"
        ]
    ]
]