From e6152b303fe5948da3f91f3384d8fb256b4e4d74 Mon Sep 17 00:00:00 2001
From: Thac-Thong Nguyen
Date: Thu, 27 May 2021 13:14:58 +0200
Subject: [PATCH] Switch to argparse (#3)

* switch to argparse for more standard CLI arguments

* new measures: Pearson rank, Spearman, and KL divergence
---
 .github/workflows/python-package.yml |  24 +++++
 diffir/__init__.py                   |   4 +
 diffir/batchrun.py                   |  10 +-
 diffir/measure/__init__.py           |  33 +++----
 diffir/measure/qrels.py              |  44 +++++----
 diffir/measure/unsupervised.py       |  62 ++++++------
 diffir/run.py                        | 136 +++++++++++++--------------
 diffir/templates/template.html       |   2 +-
 diffir/test/test_tauap.py            |  55 +++++++++++
 diffir/weight/__init__.py            |  13 +--
 diffir/weight/custom.py              |  28 +++---
 diffir/weight/unsupervised.py        |  22 ++---
 requirements.txt                     |   6 +-
 13 files changed, 256 insertions(+), 183 deletions(-)
 create mode 100644 .github/workflows/python-package.yml
 create mode 100644 diffir/test/test_tauap.py

diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
new file mode 100644
index 0000000..143210a
--- /dev/null
+++ b/.github/workflows/python-package.yml
@@ -0,0 +1,24 @@
+name: test
+on: [push]
+jobs:
+  test:
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.6, 3.7, 3.8]
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        pip install pytest
+        pip install -r requirements.txt
+    - name: Test with pytest
+      run: |
+        export PYTHONPATH=${PYTHONPATH}:/home/runner/work/capreolus/diffir
+        pytest
+
diff --git a/diffir/__init__.py b/diffir/__init__.py
index 419ff56..20912cd 100644
--- a/diffir/__init__.py
+++ b/diffir/__init__.py
@@ -1,4 +1,8 @@
 __version__ = "0.1.0"
 
 from diffir.weight import Weight
+from diffir.weight.custom import CustomWeight
+from diffir.weight.unsupervised import ExactMatchWeight
 from diffir.measure import Measure
+from diffir.measure.qrels import QrelMeasure
+from diffir.measure.unsupervised import TopkMeasure
diff --git a/diffir/batchrun.py b/diffir/batchrun.py
index 7bc2bc9..2e664ac 100644
--- a/diffir/batchrun.py
+++ b/diffir/batchrun.py
@@ -29,9 +29,13 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("directory")
    parser.add_argument("-o", "--output", dest="output_dir")
-    parser.add_argument("--config", dest="config", nargs="*")
-
+    parser.add_argument("--dataset", dest="dataset", type=str, help="dataset from ir_datasets")
+    parser.add_argument("--measure", dest="measure", type=str, default="tauap", help="measure for ranking difference (qrel, tauap, weightedtau)")
+    parser.add_argument("--metric", dest="metric", type=str, default="MAP", help="metric used with qrel measure")
+    parser.add_argument("--topk", dest="topk", type=int, default=10)
     args = parser.parse_args()
+    config = {"dataset": args.dataset, "measure": args.measure, "metric": args.metric, "topk": args.topk,
+              "weight": {"weights_1": None, "weights_2": None}}
     indir = Path(args.directory)
     output = Path(args.output_dir) if args.output_dir else indir / "diffir"
     output.mkdir(exist_ok=True)
@@ -50,7 +54,7 @@ def main():
     single_runs = sorted(single_runs)  # sorted needed for itertools ordering
     queue = [(fn,) for fn in single_runs] + list(itertools.combinations(single_runs, 2))
 
-    f = partial(process_runs, config=args.config, output=output)
+    f = partial(process_runs, config=config, output=output)
 
     with multiprocessing.Pool(8) as p:
         outdirs = p.map(f, queue)
diff --git a/diffir/measure/__init__.py b/diffir/measure/__init__.py
index 01ef700..4a947fa 100644
--- a/diffir/measure/__init__.py
+++ b/diffir/measure/__init__.py
@@ -1,26 +1,27 @@
-from profane import ModuleBase, import_all_modules, ConfigOption
+class Measure:
+    def __init__(self, metric="ndcg_20", topk=5):
+        '''
+        Measure construction
+        :param metric: The metric used for selecting queries.
+        :param topk: How many queries to retrieve
+        '''
+        self.metric = metric
+        self.topk = topk
-
-class Measure(ModuleBase):
-    module_type = "measure"
-    config_spec = [
-        ConfigOption(key="metric", default_value="ndcg_20", description="The metric to use for selecting queries"),
-        ConfigOption(key="topk", default_value=5, description="How many queries to retrieve"),
-    ]
-
-    # TODO finalize API
     def query_differences(self, run1, run2, *args, **kwargs):
+        '''
+        :param run1: the first run
+        :param run2: the second run
+        :param args: additional positional arguments forwarded to the concrete measure
+        :param kwargs: additional keyword arguments (e.g. dataset) forwarded to the concrete measure
+        :return: the most-differing query ids, a qid-to-difference dict, and the measure name (plus per-run scores when available)
+        '''
         if run1 and run2:
             return self._query_differences(run1, run2, *args, **kwargs)
         elif run1 and run2 is None:
-            qids = sorted(list(run1.keys()))[: self.config["topk"]]
+            qids = sorted(list(run1.keys()))[: self.topk]
             id2diff = {qid: 0 for qid in qids}
             return qids, id2diff, "singlerun"
 
     def _query_differences(self, run1, run2, *args, **kwargs):
         raise NotImplementedError
-
-
-# TODO this is going to break once we introduce optional modules. need a way for them to fail gracefully.
-# or to enumerate/register them without importing the py file?
-import_all_modules(__file__, __package__)
diff --git a/diffir/measure/qrels.py b/diffir/measure/qrels.py
index da7694d..5113b6a 100644
--- a/diffir/measure/qrels.py
+++ b/diffir/measure/qrels.py
@@ -1,17 +1,10 @@
-import pytrec_eval
-from profane import ModuleBase, Dependency, ConfigOption
+from ir_measures import iter_calc, parse_measure
+import sys
 
 from diffir.measure import Measure
 
 
-@Measure.register
 class QrelMeasure(Measure):
     module_name = "qrel"
-
-    config_spec = [
-        ConfigOption(key="topk", default_value=10, description="The number of differing queries to return"),
-        ConfigOption(key="metric", default_value="ndcg_cut_20", description="TODO"),
-    ]
-
     def _query_differences(self, run1, run2, *args, **kwargs):
         """
         :param run1: TREC run. Has the format {qid: {docid: score}, ...}
@@ -28,13 +21,32 @@ def _query_differences(self, run1, run2, *args, **kwargs):
         run2 = {qid: doc_id_to_score for qid, doc_id_to_score in run2.items() if qid in overlapping_keys}
 
         qrels = dataset.qrels_dict()
-        metric = self.config["metric"]
-        topk = self.config["topk"]
-        evaluator = pytrec_eval.RelevanceEvaluator(qrels, {metric})
-        eval_run_1 = evaluator.evaluate(run1)
-        eval_run_2 = evaluator.evaluate(run2)
+        try:
+            metric = parse_measure(self.metric)
+        except NameError:
+            print("Unknown measure: {}. Please provide a measure supported by https://ir-measur.es/".format(self.metric))
+            sys.exit(1)
+
+        topk = self.topk
+        eval_run_1 = self.convert_to_nested_dict(iter_calc([metric], qrels, run1))
+        eval_run_2 = self.convert_to_nested_dict(iter_calc([metric], qrels, run2))
+
         query_ids = eval_run_1.keys() & eval_run_2.keys()
         query_ids = sorted(query_ids, key=lambda x: abs(eval_run_1[x][metric] - eval_run_2[x][metric]), reverse=True)
         query_ids = query_ids[:topk]
-        id2diff = {x:abs(eval_run_1[x][metric] - eval_run_2[x][metric]) for x in query_ids}
-        return query_ids, id2diff, metric
+        id2diff = {x: abs(eval_run_1[x][metric] - eval_run_2[x][metric]) for x in query_ids}
+        id2qrelscores = {x: [eval_run_1[x][metric], eval_run_2[x][metric]] for x in query_ids}
+        return query_ids, id2diff, self.metric, id2qrelscores
+
+    def convert_to_nested_dict(self, ir_measures_iterator):
+        """
+        Util method to convert the results from ir_measures.iter_calc to a dict.
+        TODO: We can probably refactor so that this method won't be needed
+        """
+        eval_dict = {}
+
+        for x in ir_measures_iterator:
+            # TODO: This assumes that there would be only one measure/metric to handle.
+            eval_dict[x.query_id] = {x.measure: x.value}
+
+        return eval_dict
diff --git a/diffir/measure/unsupervised.py b/diffir/measure/unsupervised.py
index 9a7af69..7a2d584 100644
--- a/diffir/measure/unsupervised.py
+++ b/diffir/measure/unsupervised.py
@@ -1,17 +1,10 @@
-from profane import ModuleBase, Dependency, ConfigOption
 from diffir.measure import Measure
 from scipy import stats
 import numpy as np
-import math
 
 
-@Measure.register
 class TopkMeasure(Measure):
     module_name = "topk"
-    config_spec = [
-        ConfigOption(key="topk", default_value=3, description="TODO"),
-        ConfigOption(key="metric", default_value="weightedtau", description="Metric to measure the rank correaltion"),
-    ]
 
     def tauap(self, x, y, decreasing=True):
         """
@@ -47,8 +40,8 @@ def tauap_fast(self, x, y):
         n = len(ry)
         if n == 1:
             return 1
-        ordered_idx = sorted(list(range(n)), key=lambda i: rx[i])
-        ry_ordered_by_rx = [(ry[idx], i) for i, idx in enumerate(ordered_idx)]
+        ordered_idx = sorted(list(range(n)), key=lambda i: ry[i])
+        rx_ordered_by_ry = [(rx[idx], i) for i, idx in enumerate(ordered_idx)]
 
         def merge_sort(arr):
             if len(arr) <= 1:
@@ -82,38 +75,37 @@ def merge_sort(arr):
                     j += 1
                 k += 1
             return tauAP
-
-        res = (2 - 2 * merge_sort(ry_ordered_by_rx) / (n - 1)) - 1
+        res = (2 - 2 * merge_sort(rx_ordered_by_ry) / (n - 1)) - 1
         return res
 
     def pearson_rank(self, x, y):
-        x = np.interp(x, (min(x), max(x)), (0,1))
-        y = np.interp(y, (min(y), max(y)), (0,1))
-        indices = sorted(list(range(len(x))), key=lambda idx : x[idx], reverse=True)
+        x = np.interp(x, (min(x), max(x)), (0, 1))
+        y = np.interp(y, (min(y), max(y)), (0, 1))
+        indices = sorted(list(range(len(x))), key=lambda idx: x[idx], reverse=True)
         x = x[indices]
         y = y[indices]
-        x_diff = x.reshape(1,-1) - x.reshape(-1,1)
-        y_diff = y.reshape(1,-1) - y.reshape(-1,1)
+        x_diff = x.reshape(1, -1) - x.reshape(-1, 1)
+        y_diff = y.reshape(1, -1) - y.reshape(-1, 1)
         den = x[1:].sum()
         pr = 0
-        mask = np.tril(np.ones((len(x),len(x))),k=-1)
-        xy = x_diff*y_diff*mask
-        xx = x_diff*x_diff*mask
-        yy = y_diff*y_diff*mask
+        mask = np.tril(np.ones((len(x), len(x))), k=-1)
+        xy = x_diff * y_diff * mask
+        xx = x_diff * x_diff * mask
+        yy = y_diff * y_diff * mask
         xy = xy.sum(axis=1)[1:]
         xx = xx.sum(axis=1)[1:]
         yy = yy.sum(axis=1)[1:]
-        den_i = np.sqrt(xx)*np.sqrt(yy)
-        den_i[den_i==0]=1e-5
-        res = (xy*x[1:]/den_i).sum()/den
+        den_i = np.sqrt(xx) * np.sqrt(yy)
+        den_i[den_i == 0] = 1e-5
+        res = (xy * x[1:] / den_i).sum() / den
         return res
 
     def kl_div(self, x, y):
         x = np.array(x) - min(x) + 1e-5
         y = np.array(y) - min(y) + 1e-5
-        x = x/x.sum()
-        y = y/y.sum()
-        return -(stats.entropy(x,y)+stats.entropy(y,x))/2
+        x = x / x.sum()
+        y = y / y.sum()
+        return -(stats.entropy(x, y) + stats.entropy(y, x)) / 2
 
     def _query_differences(self, run1, run2, *args, **kwargs):
         """
         :param run1: TREC run. Has the format {qid: {docid: score}, ...}
         :param run2: Same as above
         :param args:
         :param kwargs:
         :return: The union of top k qids in both runs, sorted by the order in which the queries appear in run 1
         ^ This is because run 1 appears on the left hand side in the web ui
         """
-        topk = self.config["topk"]
-        metric = self.config["metric"]
+        topk = self.topk
+        metric = self.metric
         qids = run1.keys() & run2.keys()
         if not qids:
             raise ValueError("run1 and run2 have no shared qids")
 
         id2measure = {}
         for qid in qids:
             from collections import defaultdict
-            min_value = min(min(run1[qid].values()), min(run2[qid].values()))-1e-5
+            min_value = min(min(run1[qid].values()), min(run2[qid].values())) - 1e-5
             doc_score_1 = defaultdict(lambda: min_value, run1[qid])
             doc_score_2 = defaultdict(lambda: min_value, run2[qid])
             doc_ids_1 = doc_score_1.keys()
             doc_ids_2 = doc_score_2.keys()
             doc_ids_union = set(doc_ids_1).union(set(doc_ids_2))
-            doc_ids_union = sorted(list(doc_ids_union), key=lambda id: (doc_score_1[id] + doc_score_2[id]), reverse=True)
+            doc_ids_union = sorted(list(doc_ids_union), key=lambda id: (doc_score_1[id] + doc_score_2[id]),
+                                   reverse=True)
             union_score1 = [doc_score_1[doc_id] for doc_id in doc_ids_union]
             union_score2 = [doc_score_2[doc_id] for doc_id in doc_ids_union]
             if metric == "weightedtau":
                 tau, p_value = stats.weightedtau(union_score1, union_score2)
             elif metric == "tauap":
                 tau = self.tauap_fast(union_score1, union_score2)
             elif metric == "spearmanr":
                 tau, p_value = stats.spearmanr(union_score1, union_score2)
-            elif metric == "pearsonr":
-                tau = (self.pearson_rank(union_score1, union_score2)+self.pearson_rank(union_score2, union_score1))/2
+            elif metric == "pearsonrank":
+                tau = (self.pearson_rank(union_score1, union_score2) + self.pearson_rank(union_score2,
+                                                                                         union_score1)) / 2
             elif metric == "kldiv":
                 tau = self.kl_div(union_score1, union_score2)
             else:
-                raise ValueError("Metric {} not supported for the measure {}".format(self.config["metric"], self.module_name))
+                raise ValueError("Metric {} not supported for the measure {}".format(self.metric, self.module_name))
             id2measure[qid] = tau
         qids = sorted(qids, key=lambda x: id2measure[x])
         qids = qids[:topk]
         id2measure = {idx: id2measure[idx] for idx in qids}
-        return qids, id2measure, metric
+        return qids, id2measure, metric, None
diff --git a/diffir/run.py b/diffir/run.py
index bc064d1..5e8c66c 100644
--- a/diffir/run.py
+++ b/diffir/run.py
@@ -1,99 +1,85 @@
 import os
-import sys
-from collections import defaultdict
-
 import numpy as np
 import argparse
 import json
 from collections import defaultdict
 from tqdm import tqdm
-
 from intervaltree import IntervalTree, Interval
-
-from docopt import docopt
-
-from profane import config_list_to_dict, constants, ConfigOption, Dependency, ModuleBase
-
 from mako.template import Template
 from rich.console import Console
 from rich.table import Table
-from rich import print as rprint
 from rich.prompt import Confirm
-from rich.live import Live
 from rich.panel import Panel
-from rich.layout import Layout
 import ir_datasets
+
+from diffir import QrelMeasure, TopkMeasure, CustomWeight, ExactMatchWeight
 from diffir.utils import load_trec_run
 
 _logger = ir_datasets.log.easy()
 
-# specify a base package that profane should look for modules under
-constants["BASE_PACKAGE"] = "diffir"
-
 
 def main():
-    help = """
-    Usage:
-        run.py RUN1 RUN2 [options] [([with] CONFIG...)]
-        run.py (-h | --help)
-
-
-    Options:
-      -c --cli      CLI mode (default)
-      -w --web      webui mode
-      -h --help     Print this help message and exit.
-
-
-    Arguments:
-      RUN1    First run file to compare
-      RUN2    Second run file to compare
-      CONFIG  Configuration assignments of the form foo.bar=17
-
-    """
-
     parser = argparse.ArgumentParser()
-    parser.add_argument("runfiles", nargs="+")
-    parser.add_argument("-c", "--cli", dest="cli", action="store_true")
-    parser.add_argument("-w", "--web", dest="web", action="store_true")
-    parser.add_argument("--config", dest="config", nargs="*")
-
+    parser.add_argument("runfiles", nargs="+", help="run file(s) to display and compare")
+    parser.add_argument("-c", "--cli", dest="cli", action="store_true", help="output to CLI (default)")
+    parser.add_argument("-w", "--web", dest="web", action="store_true", help="output HTML file for WebUI")
+    parser.add_argument("--dataset", dest="dataset", type=str, required=True, help="dataset identifier from ir_datasets")
+    parser.add_argument("--measure", dest="measure", type=str, default="tauap", help="measure for ranking difference (qrel, tauap, weightedtau)")
+    parser.add_argument("--metric", dest="metric", type=str, default="nDCG@10", help="metric to report and used with qrel measure")
+    parser.add_argument("--topk", dest="topk", type=int, default=50, help="number of queries to compare")
+    parser.add_argument("--weights_1", dest="weights_1", type=str, default=None, required=False)
+    parser.add_argument("--weights_2", dest="weights_2", type=str, default=None, required=False)
     args = parser.parse_args()
-    diff(args.runfiles, args.config, cli=args.cli, web=args.web)
+    config = {"dataset": args.dataset, "measure": args.measure, "metric": args.metric, "topk": args.topk,
+              "weight": {"weights_1": args.weights_1, "weights_2": args.weights_2}}
+    if not (args.cli or args.web):
+        args.cli = True  # default
+    diff(args.runfiles, config, cli=args.cli, web=args.web)
 
 
 def diff(runs, config, cli, web, print_html=True):
-    config = config_list_to_dict(config) if config else {}
-
-    # hack to automatically add weight files when available and not already specified
-    config_with_defaults = MainTask(config, build=False).config
     for i, run in enumerate(runs):
-        if f"weights_{i+1}" in config_with_defaults["weight"] and config_with_defaults["weight"][f"weights_{i+1}"] is None:
+        if config["weight"][f"weights_{i + 1}"] is None:
             if os.path.exists(run + ".diffir"):
-                config.setdefault("weight", {})[f"weights_{i+1}"] = run + ".diffir"
-
-    task = MainTask(config)
+                _logger.info("Found weight file at {}".format(run+".diffir"))
+                config["weight"][f"weights_{i + 1}"] = run + ".diffir"
+    task = MainTask(**config)
     if cli:
         task.cli(runs)
     if web:
        html = task.web(runs)
        if print_html:
            print(html)
-    return task.config, html
+    return config, html
 
 
-class MainTask(ModuleBase):
+class MainTask:
     module_type = "task"
     module_name = "main"
-    config_spec = [
-        ConfigOption(key="dataset", default_value="none", description="TODO"),
-        ConfigOption(key="queries", default_value="none", description="TODO"),
-    ]
-    dependencies = [
-        Dependency(key="measure", module="measure", name="topk"),
-        Dependency(key="weight", module="weight", name="exactmatch"),
-    ]
-
-    def create_query_objects(self, run_1, run_2, qids, qid2diff, metric_name, dataset):
+
+    def __init__(self, dataset="none", queries="none", measure="topk", metric="weighted_tau", topk=3, weight={}):
+        self.dataset = dataset
+        self.queries = queries
+        if measure == "qrel":
+            self.measure = QrelMeasure(metric, topk)
+        elif measure == "tauap":
+            self.measure = TopkMeasure("tauap", topk)
+        elif measure == "weightedtau":
+            self.measure = TopkMeasure("weightedtau", topk)
+        elif measure == "spearmanr":
+            self.measure = TopkMeasure("spearmanr", topk)
+        elif measure == "pearsonrank":
+            self.measure = TopkMeasure("pearsonrank", topk)
+        elif measure == "kldiv":
+            self.measure = TopkMeasure("kldiv", topk)
+        else:
+            raise ValueError("Measure {} is not supported".format(measure))
+        if weight["weights_1"] or weight["weights_2"]:
+            self.weight = CustomWeight(weight["weights_1"], weight["weights_2"])
+        else:
+            self.weight = ExactMatchWeight()
+
+    def create_query_objects(self, run_1, run_2, qids, qid2diff, metric_name, dataset, qid2qrelscores=None):
         """
         TODO: Need a better name
         This method takes in 2 runs and a set of qids, and constructs a dict for each qid (format specified below)
@@ -138,7 +124,10 @@ def create_query_objects(self, run_1, run_2, qids, qid2diff, metric_name, datase
             )
 
             fields = query._asdict()
-            fields["metric"]={"name": metric_name, "value": qid2diff[query.query_id]}
+            fields["metric"] = {"name": metric_name, "value": qid2diff[query.query_id]}
+            if qid2qrelscores:
+                fields[f'Run1 {metric_name}'] = qid2qrelscores[query.query_id][0]
+                fields[f'Run2 {metric_name}'] = qid2qrelscores[query.query_id][1]
             qrels_for_query = qrels.get(query.query_id, {})
             run_1_for_query = []
             for rank, (doc_id, score) in enumerate(run_1[query.query_id].items()):
@@ -218,7 +207,8 @@ def merge_weights(self, run1_for_query, run_2_for_query):
                 for segment in doc_id2weights[doc_id]["run2"].get(field, []):
                     t.add(Interval(segment[0], segment[1], {"run2": segment[2]}))
                 t.split_overlaps()
-                t.merge_equals(lambda old_dict, new_dict: old_dict.update(new_dict) or old_dict, {"run1": None, "run2": None})
+                t.merge_equals(lambda old_dict, new_dict: old_dict.update(new_dict) or old_dict,
+                               {"run1": None, "run2": None})
                 merged_intervals = sorted([(i.begin, i.end, i.data) for i in t], key=lambda x: (x[0], x[1]))
 
                 merged_weights[doc_id][field] = merged_intervals
@@ -257,7 +247,7 @@ def find_snippet(self, weights, doc):
                    top_field = field
        # reconstruct the snippet
        if top_snippet_score > 0:
-            snp_weights = sorted(weights[top_field])[top_range[0] : top_range[1]]
+            snp_weights = sorted(weights[top_field])[top_range[0]: top_range[1]]
            # start = max(snp_weights[0][1] - max(0, (MAX_SNIPPET_LEN - snp_weights[-1][0] + snp_weights[0][1])/2), 0)
            start = max(snp_weights[0][0] - 5, 0)
            stop = start + MAX_SNIPPET_LEN
@@ -292,7 +282,7 @@ def create_doc_objects(self, query_objects, dataset):
                 doc_ids_to_fetch.add(listed_doc["doc_id"])
 
         for doc in _logger.pbar(
-            dataset.docs_store().get_many_iter(doc_ids_to_fetch), desc="Docs iter", total=len(doc_ids_to_fetch)
+                dataset.docs_store().get_many_iter(doc_ids_to_fetch), desc="Docs iter", total=len(doc_ids_to_fetch)
         ):
             doc_objects[doc.doc_id] = doc._asdict()
 
@@ -307,15 +297,14 @@ def json(self, run_1_fn, run_2_fn=None):
         run_1 = load_trec_run(run_1_fn)
         run_2 = load_trec_run(run_2_fn) if run_2_fn is not None else None
 
-        dataset = ir_datasets.load(self.config["dataset"])
+        dataset = ir_datasets.load(self.dataset)
         assert dataset.has_docs()
         assert dataset.has_queries()
         # TODO: handle the case without qrels
         assert dataset.has_queries()
-
-        diff_queries, qid2diff, metric_name = self.measure.query_differences(run_1, run_2, dataset=dataset)
+        diff_queries, qid2diff, metric_name, qid2qrelscores = self.measure.query_differences(run_1, run_2, dataset=dataset)
         # _logger.info(diff_queries)
-        diff_query_objects = self.create_query_objects(run_1, run_2, diff_queries, qid2diff, metric_name, dataset)
+        diff_query_objects = self.create_query_objects(run_1, run_2, diff_queries, qid2diff, metric_name, dataset, qid2qrelscores=qid2qrelscores)
         doc_objects = self.create_doc_objects(diff_query_objects, dataset)
 
         return json.dumps(
@@ -323,7 +313,7 @@ def json(self, run_1_fn, run_2_fn=None):
                 "meta": {
                     "run1_name": run_1_fn,
                     "run2_name": run_2_fn,
-                    "dataset": self.config["dataset"],
+                    "dataset": self.dataset,
                     "measure": self.measure.module_name,
                     "weight": self.weight.module_name,
                     "qrelDefs": dataset.qrels_defs(),
@@ -344,7 +334,7 @@ def print_query_to_console(self, q, console):
         console.print(query_panel)
 
     def render_snippet_for_cli(self, doc_id, snp, docs):
-        snp_text = docs[doc_id][snp["field"]][snp["start"] : snp["stop"]]
+        snp_text = docs[doc_id][snp["field"]][snp["start"]: snp["stop"]]
         idx_change = 0
         for s, e, w in snp["weights"]:
             s = s + idx_change
@@ -450,7 +440,8 @@ def cli(self, runs):
         if len(runs) == 2:
             for current_index in range(len(queries)):
                 self.cli_compare_one_query(
-                    console, queries[current_index], 0, None, docs, json_data["meta"]["run1_name"], json_data["meta"]["run2_name"]
+                    console, queries[current_index], 0, None, docs, json_data["meta"]["run1_name"],
+                    json_data["meta"]["run2_name"]
                 )
                 ans = Confirm.ask("Want to see the next query?")
                 if not ans:
@@ -458,7 +449,8 @@ def cli(self, runs):
         else:
             with console.pager():
                 for current_index in range(len(queries)):
-                    self.cli_display_one_query(console, queries[current_index], 0, None, docs, json_data["meta"]["run1_name"])
+                    self.cli_display_one_query(console, queries[current_index], 0, None, docs,
+                                               json_data["meta"]["run1_name"])
 
     def web(self, runs):
         json_data = self.json(*runs)
diff --git a/diffir/templates/template.html b/diffir/templates/template.html
index d671344..bc1bc42 100644
--- a/diffir/templates/template.html
+++ b/diffir/templates/template.html
@@ -618,4 +618,4 @@
 
 
 
-
\ No newline at end of file
+
diff --git a/diffir/test/test_tauap.py b/diffir/test/test_tauap.py
new file mode 100644
index 0000000..4fd36f7
--- /dev/null
+++ b/diffir/test/test_tauap.py
@@ -0,0 +1,55 @@
+from diffir.measure.unsupervised import TopkMeasure
+
+class TestUnsupervisedMeasure:
+    def test_tauap_one(self):
+        measure = TopkMeasure()
+        x=[1,2,3]
+        y=[1,2,3]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == 1
+    def test_tauap_two(self):
+        measure = TopkMeasure()
+        x=[1,2,3]
+        y=[3,2,1]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == -1
+
+    def test_tauap_three(self):
+        measure = TopkMeasure()
+        x=[1,2,5]
+        y=[1,9,10]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == 1
+
+    def test_tauap_three_b(self):
+        measure = TopkMeasure()
+        x = [3, 1, 2]
+        y = [1, 2, 3]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == 0
+
+    def test_tauap_four(self):
+        measure = TopkMeasure()
+        x = [1, 2, 4, 3, 5]
+        y = [1, 2, 3, 4, 5]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == 0.75
+
+    def test_tauap_five(self):
+        measure = TopkMeasure()
+        x = [2, 1, 3, 4, 5]
+        y = [1, 2, 3, 4, 5]
+        tauap = measure.tauap(x,y)
+        tauap_fast = measure.tauap_fast(x,y)
+        assert tauap == tauap_fast
+        assert tauap == 0.875
diff --git a/diffir/weight/__init__.py b/diffir/weight/__init__.py
index d640dec..da9623d 100644
--- a/diffir/weight/__init__.py
+++ b/diffir/weight/__init__.py
@@ -1,14 +1,3 @@
-from profane import ModuleBase, import_all_modules
-
-
-class Weight(ModuleBase):
-    module_type = "weight"
-
-    # TODO finalize API
+class Weight:
     def score_document_regions(self, query, doc, run_idx):
         raise NotImplementedError()
-
-
-# TODO this is going to break once we introduce optional modules. need a way for them to fail gracefully.
-# or to enumerate/register them without importing the py file?
-import_all_modules(__file__, __package__)
diff --git a/diffir/weight/custom.py b/diffir/weight/custom.py
index cb0052f..49636e1 100644
--- a/diffir/weight/custom.py
+++ b/diffir/weight/custom.py
@@ -1,31 +1,29 @@
 import json
-from intervaltree import IntervalTree
-from nltk import word_tokenize
-from nltk.corpus import stopwords
-from profane import ModuleBase, Dependency, ConfigOption
 import ir_datasets
 
 from . import Weight
 
-
 _logger = ir_datasets.log.easy()
 
 
-@Weight.register
 class CustomWeight(Weight):
     module_name = "custom"
-    config_spec = [
-        # TODO: is there a better way to handle these args? There's strlist, but that's ugly for file paths.
-        # Maybe we could infer them if they are named with the run prefix? Can we get the run paths here?
-        ConfigOption(key="weights_1", default_value=None, description="TODO"),
-        ConfigOption(key="norm_1", default_value="minmax", description="TODO"),
-        ConfigOption(key="weights_2", default_value=None, description="TODO"),
-        ConfigOption(key="norm_2", default_value="minmax", description="TODO"),
-    ]
+
+    def __init__(self, weights_1, weights_2, norm_1="minmax", norm_2="minmax"):
+        '''
+        Custom weights loaded from files produced by ranking models
+        :param weights_1: path to the weights file for run 1
+        :param weights_2: path to the weights file for run 2
+        :param norm_1: normalization applied to the run 1 weights (default: minmax)
+        :param norm_2: normalization applied to the run 2 weights (default: minmax)
+        '''
+        self.weights = [weights_1, weights_2]
+        self.norms = [norm_1, norm_2]
+        self.build()
 
     def build(self):
         self._cache = {}
         for run_idx in [0, 1]:
-            weights_file, norm = self.config[f"weights_{run_idx+1}"], self.config[f"norm_{run_idx+1}"]
+            weights_file, norm = self.weights[run_idx], self.norms[run_idx]
             if weights_file is None:
                 _logger.warn(f"missing weights.weights_{run_idx + 1}")
                 self._cache[run_idx] = {}
diff --git a/diffir/weight/unsupervised.py b/diffir/weight/unsupervised.py
index dc85431..f6d7526 100644
--- a/diffir/weight/unsupervised.py
+++ b/diffir/weight/unsupervised.py
@@ -3,20 +3,16 @@
 from intervaltree import IntervalTree
 from nltk import word_tokenize
 from nltk.corpus import stopwords
-from profane import ModuleBase, Dependency, ConfigOption
 
 from . import Weight
 import ahocorasick
 
 
-@Weight.register
 class ExactMatchWeight(Weight):
     module_name = "exactmatch"
-    config_spec = [
-        ConfigOption(key="skip_stopwords", default_value=True, description="TODO"),
-        ConfigOption(
-            key="queryfield", default_value="", value_type="strlist", description="The query field that is used for highlighting"
-        ),
-    ]
+
+    def __init__(self, queryfield="", skip_stopwords=True):
+        self.queryfield = queryfield
+        self.skip_stopwords = skip_stopwords
 
     def fast_score_document_regions(self, query, doc, run_idx):
         """
@@ -35,7 +31,7 @@ def fast_score_document_regions(self, query, doc, run_idx):
         except LookupError:
             nltk.download("stopwords")
         result = {}
-        stops = stopwords.words("english") if self.config["skip_stopwords"] else None
+        stops = stopwords.words("english") if self.skip_stopwords else None
         query_tokens = set()
         for qfield_value in query:
             query_tokens.update(
@@ -59,7 +55,7 @@ def fast_score_document_regions(self, query, doc, run_idx):
         for field, values in list(result.items()):
             tree = IntervalTree()
             for start, stop in values:
-                tree[start : stop + 1] = 1
+                tree[start: stop + 1] = 1
             tree.merge_overlaps()
             result[field] = sorted([[i.begin, i.end, 1.0] for i in tree])
         return result
@@ -77,10 +73,10 @@ def score_document_regions(self, query, doc, run_idx, fast=False):
         except LookupError:
             nltk.download("stopwords")
         result = {}
-        stops = stopwords.words("english") if self.config["skip_stopwords"] else None
+        stops = stopwords.words("english") if self.skip_stopwords else None
 
         qfield_values = []
-        specified_qfields = list(filter(None, self.config["queryfield"]))
+        specified_qfields = list(filter(None, self.queryfield))
 
         # Choose a query field to do the highlighting with
         if specified_qfields:
@@ -104,7 +100,7 @@ def score_document_regions(self, query, doc, run_idx, fast=False):
             if stops and word.lower() in stops:
                 continue
             for dfield, dvalue in zip(doc._fields, doc):
-                if not isinstance(dvalue, str):  # TODO: how to handle other field types (like the structured CORD19 docs)?
+                if not isinstance(dvalue, str):
                     continue  # skip non-strings for now
                 if dfield not in result:
                     result[dfield] = []
diff --git a/requirements.txt b/requirements.txt
index 8e76f13..0ae48b0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,11 @@
-profane>=0.2.3,<0.3
+nltk>=3.5
+ir_measures>=0.1.4
 mako~=1.1
 ir_datasets>=0.3.1
 pytrec_eval>=0.5
 intervaltree>=3.1.0
 rich>=9.13.0
 pyahocorasick>=1.4.1
+numpy
+scipy
+pandas
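
Reviewer note (not part of the patch): with the profane config strings gone, a comparison can now be driven directly from Python using a plain config dict. The sketch below assumes the patched diffir.run.diff API shown above; the dataset identifier and run file names are placeholders, not files shipped with the project.

    # Minimal sketch of the post-patch API (run files and dataset id are placeholders).
    from diffir.run import diff

    config = {
        "dataset": "antique/test",   # any ir_datasets identifier with docs, queries, and qrels
        "measure": "tauap",          # one of: qrel, tauap, weightedtau, spearmanr, pearsonrank, kldiv
        "metric": "nDCG@10",         # only consulted when measure == "qrel"
        "topk": 10,                  # number of most-differing queries to keep
        "weight": {"weights_1": None, "weights_2": None},  # None -> ExactMatchWeight highlighting
    }

    # diff() builds MainTask(**config); web=True renders the HTML report, cli=True prints to the console.
    config, html = diff(["run_a.trec", "run_b.trec"], config, cli=False, web=True, print_html=False)
    with open("diffir.html", "w") as f:
        f.write(html)

The equivalent command-line invocation goes through the new argparse flags added to run.py (--dataset, --measure, --metric, --topk, and optional --weights_1/--weights_2) instead of the old foo.bar=17 configuration assignments.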