Skip to content

Commit

Permalink
implment symmetric KL divergence
Browse files Browse the repository at this point in the history
  • Loading branch information
thongnt99 committed Mar 13, 2021
1 parent edde557 commit 5df3ac2
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions diffir/measure/unsupervised.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,14 @@ def pearson_rank(self, x, y):
den_i[den_i==0]=1e-5
res = (xy*x[1:]/den_i).sum()/den
return res


def kl_div(self, x, y):
x = np.array(x) - min(x) + 1e-5
y = np.array(y) - min(y) + 1e-5
x = x/x.sum()
y = y/y.sum()
return -(stats.entropy(x,y)+stats.entropy(y,x))/2

def _query_differences(self, run1, run2, *args, **kwargs):
"""
:param run1: TREC run. Has the format {qid: {docid: score}, ...}
Expand All @@ -126,9 +133,9 @@ def _query_differences(self, run1, run2, *args, **kwargs):
id2measure = {}
for qid in qids:
from collections import defaultdict

doc_score_1 = defaultdict(lambda: 0, run1[qid])
doc_score_2 = defaultdict(lambda: 0, run2[qid])
min_value = min(min(run1[qid].values()), min(run2[qid].values()))-1e-5
doc_score_1 = defaultdict(lambda: min_value, run1[qid])
doc_score_2 = defaultdict(lambda: min_value, run2[qid])
doc_ids_1 = doc_score_1.keys()
doc_ids_2 = doc_score_2.keys()
doc_ids_union = set(doc_ids_1).union(set(doc_ids_2))
Expand All @@ -143,6 +150,8 @@ def _query_differences(self, run1, run2, *args, **kwargs):
tau, p_value = stats.spearmanr(union_score1, union_score2)
elif metric == "pearsonr":
tau = (self.pearson_rank(union_score1, union_score2)+self.pearson_rank(union_score2, union_score1))/2
elif metric == "kldiv":
tau = self.kl_div(union_score1, union_score2)
else:
raise ValueError("Metric {} not supported for the measure {}".format(self.config["metric"], self.module_name))
id2measure[qid] = tau
Expand Down

0 comments on commit 5df3ac2

Please sign in to comment.