-
Notifications
You must be signed in to change notification settings - Fork 1
/
p@n.py.bk
40 lines (28 loc) · 837 Bytes
/
p@n.py.bk
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def predict_proba(self, seqs, batch_size=64):
    """Run ``self.model`` on a single sequence or on a batch of sequences.

    ``seqs`` is treated as a batch when its first element is a ``list``;
    otherwise it is treated as one sequence.

    Args:
        seqs: one sequence, or a list of sequences (each a ``list``).
        batch_size: maximum number of sequences passed to
            ``self.model.predict`` per call in the batch case.

    Returns:
        Single sequence: element ``[0]`` of the model's prediction for the
        one-item batch ``[seqs]``.
        Batch: ``np.asarray`` of the prediction rows collected from every
        minibatch, in the order the minibatches are yielded by
        ``get_minibatches_idx`` (presumably input order -- verify that it
        does not shuffle).

    Raises:
        IndexError: if ``seqs`` is empty (``seqs[0]`` is inspected).
    """
    def _predict(batch):
        # input_adapt is defined outside this view; assumed to turn raw
        # sequences into model-ready input -- TODO confirm.
        return self.model.predict(self.input_adapt(batch))

    # NOTE(review): only a Python ``list`` element marks a batch, so a 2-D
    # ndarray or a list of tuples is handled as a single sequence.
    if not isinstance(seqs[0], list):
        return _predict([seqs])[0]

    proba = []
    for _, idx in get_minibatches_idx(len(seqs), batch_size):
        batch = [seqs[i] for i in idx]
        try:
            batch = np.asarray(batch)
        except ValueError:
            # Variable-length sequences: NumPy >= 1.24 refuses to build a
            # ragged array implicitly and raises ValueError. Build the
            # object array explicitly, which is what older NumPy returned.
            batch = np.asarray(batch, dtype=object)
        proba.extend(_predict(batch))
    return np.asarray(proba)
def precision_at_n(ys, pred_probs):
    """Return the cumulative top-n hit rate for every n from 1 to the class count.

    For each example, class ids (row positions) are ranked by descending
    score; ties keep ascending-id order because the sort is stable.
    Element ``k`` of the result is the fraction of examples whose true
    label appears among the top ``k + 1`` ranked ids. A label that matches
    no id is counted as a miss.

    Args:
        ys: true class ids, one per example (compared with ``==`` against
            integer row positions).
        pred_probs: per-example score rows; ``len(pred_probs[0])`` is taken
            as the number of classes.

    Returns:
        List of floats of length ``len(pred_probs[0])``, non-decreasing.

    Raises:
        ValueError: if there are no examples, or if ``ys`` and
            ``pred_probs`` have different lengths (the rate would otherwise
            be computed over a silently truncated set).
    """
    n_test = len(ys)
    if n_test == 0:
        raise ValueError("precision_at_n needs at least one example")
    if len(pred_probs) != n_test:
        raise ValueError(
            "ys and pred_probs differ in length (%d vs %d)"
            % (n_test, len(pred_probs)))
    y_dim = len(pred_probs[0])

    # hit[r] = number of examples whose true label is ranked exactly r-th.
    hit = [0] * y_dim
    for y, probs in zip(ys, pred_probs):
        ranked = sorted(enumerate(probs), key=lambda pair: -pair[1])
        for rank, (eid, _) in enumerate(ranked):
            if y == eid:
                hit[rank] += 1
                break  # ids are unique, so at most one position matches

    # A running total turns "ranked exactly r-th" into "within the top r+1".
    prec = []
    cumulative = 0
    for count in hit:
        cumulative += count
        prec.append(float(cumulative) / n_test)
    return prec