Skip to content

Commit

Permalink
fix warp accuracy test bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ummae committed Jul 11, 2020
1 parent 411aecd commit d074757
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 12 deletions.
7 changes: 4 additions & 3 deletions benchmark/accuracy_warp.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,18 @@ WARP | 0.17361 | 0.62401 | 0.25332 | 0.12941
Please run the following command to reproduce this experiment: `$> python test_accuracy.py compare_warp_brp ml20m`

## Compare with LightFM
*(This experiment was run without hyper-parameter tuning.)*

**Parameters**

- `num_iters`: 100
- `num_iters`: 30
- `d`: 40

**Top10** accuracy of validation samples for MovieLens100K:

method | NDCG | AUC | ACCURACY | MAP |
-- | -- | -- | -- | --
BUFFALO| 0.16562 | 0.62012 | 0.00610| 0.16562
LIGHTFM| 0.03657 | 0.50008 | 0.24548| 0.00365
BUFFALO| 0.15890| 0.62473| 0.25480| 0.11059
LIGHTFM| 0.15827| 0.61191| 0.22909| 0.12027

Please run the following command to reproduce this experiment: `$> python test_accuracy.py accuracy warp ml100k --libs=buffalo,lightfm`
18 changes: 10 additions & 8 deletions benchmark/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def get_option(self, lib_name, algo_name, **kwargs):
from buffalo.algo.options import BPRMFOption
opt = BPRMFOption().get_default_option()
opt.update({'d': kwargs.get('d', 100),
'lr': kwargs.get('lr', 0.05),
'validation': kwargs.get('validation'),
'num_iters': kwargs.get('num_iters', 10),
'num_workers': kwargs.get('num_workers', 10),
Expand All @@ -96,6 +97,7 @@ def get_option(self, lib_name, algo_name, **kwargs):
from buffalo.algo.options import WARPOption
opt = WARPOption().get_default_option()
opt.update({'d': kwargs.get('d', 100),
'lr': kwargs.get('lr', 0.05),
'validation': kwargs.get('validation'),
'num_iters': kwargs.get('num_iters', 10),
'max_trials': 100,
Expand Down Expand Up @@ -168,7 +170,7 @@ def __init__(self):

def get_database(self, name, **kwargs):
if name in ['ml20m', 'ml100k', 'kakao_reco_730m', 'kakao_brunch_12m']:
db = h5py.File(DB[name])
db = h5py.File(DB[name], 'r')
ratings = db_to_coo(db)
db.close()
return ratings
Expand Down Expand Up @@ -314,7 +316,7 @@ def __init__(self):

def get_database(self, name, **kwargs):
if name in ['ml20m', 'ml100k', 'kakao_reco_730m', 'kakao_brunch_12m']:
db = h5py.File(DB[name])
db = h5py.File(DB[name], 'r')
ratings = db_to_coo(db)
db.close()
return ratings
Expand All @@ -327,8 +329,8 @@ def bpr(self, database, **kwargs):
opts = self.get_option('lightfm', 'bpr', **kwargs)
data = self.get_database(database, **kwargs)
bpr = LightFM(loss='bpr',
no_components=kwargs.get('num_workers'))
elapsed, mem_info = self.run(bpr.fit, data, data, **opts)
no_components=kwargs.get('d'))
elapsed, mem_info = self.run(bpr.fit, data, **opts)
if kwargs.get('return_instance'):
return bpr
bpr = None
Expand All @@ -354,12 +356,12 @@ def warp(self, database, **kwargs):
data = self.get_database(database, **kwargs)
warp = LightFM(loss='warp',
learning_schedule='adagrad',
no_components=kwargs.get('num_workers'),
no_components=kwargs.get('d'),
max_sampled=100)
elapsed, mem_info = self.run(warp.fit, data, data, **opts)
elapsed, mem_info = self.run(warp.fit, data, **opts)
if kwargs.get('return_instance'):
return warp
bpr = None
warp = None
return elapsed, mem_info


Expand Down Expand Up @@ -482,7 +484,7 @@ def __init__(self):

def get_database(self, name, **kwargs):
if name in ['ml20m', 'ml100k', 'kakao_reco_730m', 'kakao_brunch_12m']:
db = h5py.File(DB[name])
db = h5py.File(DB[name], 'r')
ratings = db_to_dataframe(db, kwargs.get('spark'), kwargs.get('context'))
db.close()
return ratings
Expand Down
3 changes: 2 additions & 1 deletion benchmark/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ def _get_validation_score(algo_name, lib, database):
'validation': {'topk': 10},
'd': 40},
'warp': {'num_workers': 8,
'lr': 0.2,
'batch_mb': 4098,
'compute_loss_on_training': False,
'num_iters': 100,
'num_iters': 30,
'validation': {'topk': 10},
'd': 40}
}
Expand Down

0 comments on commit d074757

Please sign in to comment.