-
Notifications
You must be signed in to change notification settings - Fork 107
/
test_warp.py
85 lines (67 loc) · 2.54 KB
/
test_warp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import unittest
from buffalo import WARP, WARPOption, aux, set_log_level
from .base import TestBase
class TestWARP(TestBase):
    """Tests for the WARP model and its ``WARPOption`` configuration.

    Most test cases loop over the option sets returned by ``get_opts``,
    tweak a few fields, and delegate the actual checks to the shared
    ``_testN_*`` helpers inherited from ``TestBase``.
    """

    def get_opts(self):
        """Return the option sets every test is run against.

        Returns a two-element list: the unmodified default options and a
        copy whose ``score_func`` is set to ``"L2"``. Each entry comes from
        its own ``get_default_option()`` call, so tests may mutate one
        without affecting the other.
        """
        # NOTE(review): the name suggests the default score_func is a dot
        # product; confirm against WARPOption's defaults.
        opt_dot = WARPOption().get_default_option()
        opt_l2 = WARPOption().get_default_option()
        opt_l2["score_func"] = "L2"
        return [opt_dot, opt_l2]

    def test00_get_default_option(self):
        """Building the default option set must not raise."""
        WARPOption().get_default_option()
        self.assertTrue(True)

    def test01_is_valid_option(self):
        """Default options validate; ``save_best=1`` is rejected, ``False`` is accepted."""
        for opt in self.get_opts():
            self.assertTrue(WARPOption().is_valid_option(opt))
            # The integer 1 must be rejected by the validator with RuntimeError...
            opt["save_best"] = 1
            self.assertRaises(RuntimeError, WARPOption().is_valid_option, opt)
            # ...while the boolean False makes the option set valid again.
            opt["save_best"] = False
            self.assertTrue(WARPOption().is_valid_option(opt))

    def test02_init_with_dict(self):
        """WARP can be constructed directly from each option set."""
        set_log_level(3)
        for opt in self.get_opts():
            WARP(opt)
        self.assertTrue(True)

    def test03_init(self):
        """Run the shared init check with a 20-dimensional model."""
        for opt in self.get_opts():
            opt.d = 20
            self._test3_init(WARP, opt)

    def test04_train(self):
        """Run the shared training check with d=32 and 200 max trials."""
        for opt in self.get_opts():
            opt.d = 32
            # Same key as in test08_serialization. This was previously
            # misspelled ``max_tirals``, so the intended value never took
            # effect on the option set.
            opt.max_trials = 200
            self._test4_train(WARP, opt)

    def test05_validation(self):
        """Run the shared validation check with top-10 evaluation."""
        for opt in self.get_opts():
            opt.validation = aux.Option({"topk": 10})
            # NOTE(review): ndcg/map look like minimum expected scores checked
            # by the helper; confirm in TestBase._test5_validation.
            self._test5_validation(WARP, opt, ndcg=0.03, map=0.02)

    def test05_1_validation_with_callback(self):
        """Run the shared validation-callback check: 15 iterations, evaluated every 5."""
        for opt in self.get_opts():
            opt.d = 5
            opt.num_iters = 15
            opt.evaluation_period = 5
            opt.validation = aux.Option({"topk": 10})
            self._test5_1_validation_with_callback(WARP, opt)

    def test06_topk(self):
        """Run the shared top-k recommendation check with d=10."""
        for opt in self.get_opts():
            opt.d = 10
            opt.validation = aux.Option({"topk": 10})
            self._test6_topk(WARP, opt)

    def test07_train_ml_20m(self):
        """Run the shared ml-20m training check with 8 workers."""
        # NOTE(review): "ml_20m" presumably refers to the MovieLens-20M
        # dataset loaded by the helper; confirm in TestBase.
        for opt in self.get_opts():
            opt.num_workers = 8
            opt.validation = aux.Option({"topk": 10})
            self._test7_train_ml_20m(WARP, opt)

    def test08_serialization(self):
        """Run the shared serialization round-trip check."""
        for opt in self.get_opts():
            opt.d = 20
            opt.max_trials = 500
            opt.validation = aux.Option({"topk": 10})
            self._test8_serialization(WARP, opt)

    def test09_compact_serialization(self):
        """Run the shared compact-serialization check."""
        for opt in self.get_opts():
            opt.d = 10
            opt.validation = aux.Option({"topk": 10})
            self._test9_compact_serialization(WARP, opt)
if "__main__" == __name__:
    # Executed as a script: hand control to unittest's command-line runner,
    # which discovers and runs the tests defined in this module.
    unittest.main()