Skip to content

Commit

Permalink
Fixed evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
makgyver committed Mar 5, 2019
1 parent 680e1bb commit 228d321
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 14 deletions.
3 changes: 2 additions & 1 deletion PRL/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from .genP import *
from .genF import *
from .genK import *
from .evaluation import *
from .prl import *
from .solvers import *

__all__ = ["prl", "genF", "genP", "evaluation", "solvers"]
__all__ = ["prl", "genF", "genP", "genK", "evaluation", "solvers"]
9 changes: 2 additions & 7 deletions PRL/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,13 @@ def confusion_matrix(prl, gen_pref_test):
for j, (p, f) in enumerate(prl.col_list):
if prl.Q[j] > 0.0:
for c in range(prl.dim):
if p[0][1] == c:
xp = prl.gen_pref.get_pref_value(p)[0][0]
sco[c] += prl.Q[j]*prl.gen_feat.get_feat_value(f, xp)*prl.gen_feat.get_feat_value(f, x)
if p[1][1] == c:
xn = prl.gen_pref.get_pref_value(p)[1][0]
sco[c] -= prl.Q[j]*prl.gen_feat.get_feat_value(f, xn)*prl.gen_feat.get_feat_value(f, x)
sco[c] += prl.Q[j]*prl.compute_entry(, ((x, c), (-x, c)), f)

y_max = np.argmax(sco)
conf_mat[y[i], y_max] += 1

return conf_mat


def accuracy(prl, gen_pref_test, conf_matrix=None):
if type(conf_matrix) != np.ndarray:
conf_matrix = confusion_matrix(prl, gen_pref_test)
Expand Down
18 changes: 14 additions & 4 deletions PRL/prl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,17 @@ def compute_column(self, rq, f):
return np.dot(R, rq)

def compute_entry(self, p, q, f):
return np.dot(self.pref_repr(p, f), self.pref_repr(q, f))
rp = np.zeros(self.dim)
(x_p, y_p), (x_n, y_n) = p
rp[y_p] = +self.gen_feat.get_feat_value(f, x_p)
rp[y_n] = -self.gen_feat.get_feat_value(f, x_n)

rq = np.zeros(self.dim)
(x_p, y_p), (x_n, y_n) = q
rq[y_p] = +self.gen_feat.get_feat_value(f, x_p)
rq[y_n] = -self.gen_feat.get_feat_value(f, x_n)

return np.dot(rp, rq)

def _get_new_col(self):
"""Internal method that randomly pick a new column in such a way that its representation is not null and it is not already in the game matrix.
Expand Down Expand Up @@ -278,10 +288,10 @@ def compute_column(self, q, k):
"""
p_col = self.gen_pref.get_pref_value(q)
k_fun = self.gen_kernel.get_kernel_function(k)
R = np.zeros((self.n_rows, 1))
R = np.zeros(self.n_rows)
for i, r in enumerate(self.pref_list):
p_row = self.gen_pref.get_pref_value(r)
R[i, 0] = k_fun(p_col, p_row)
R[i] = k_fun(p_col, p_row)

return R

Expand Down Expand Up @@ -332,6 +342,6 @@ def fit(self, iterations=1000, verbose=False):
self.col_set.add((p, k))
self.M[:,j] = self.compute_column(p, k)
if verbose:
logging.info("# of kept columns: %d" %(np.sum(Q>0)))
logging.info("# of kept columns: %d\n" %(np.sum(Q>0)))

self.Q = Q
4 changes: 2 additions & 2 deletions run_prl.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def manage_options():
if "feat_gen" in data:
genf_class = getattr(__import__("prl.genF"), data['feat_gen'])
gen_col = genf_class(Xtr, *data['feat_gen_params'])
else:
elif "kernel_gen" in data:
genk_class = getattr(__import__("prl.genK"), data['kernel_gen'])
gen_col = genf_class(*data['kernel_gen_params'])
gen_col = genk_class(*data['kernel_gen_params'])

if data["pref_generator"] == "micro":
gen_pref_training = GenMicroP(Xtr, ytr)
Expand Down

0 comments on commit 228d321

Please sign in to comment.