Fix imports (#79)
* Fix imports

* Fix lint error

* Fix lint error (import rule)

* Fix lint error (Flake8)

* Fix lint error (cpplint)
dkkim1005 committed Dec 18, 2023
1 parent d2d9b19 commit d068b23
Showing 4 changed files with 21 additions and 20 deletions.
buffalo/__init__.py (7 changes: 4 additions & 3 deletions)
@@ -3,12 +3,13 @@
 __version__ = importlib.metadata.version('buffalo')

 from buffalo.algo.als import ALS, inited_CUALS
-from buffalo.algo.eals import EALS
 from buffalo.algo.base import Algo
 from buffalo.algo.bpr import BPRMF, inited_CUBPR
 from buffalo.algo.cfr import CFR
-from buffalo.algo.options import (AlgoOption, ALSOption, EALSOption, BPRMFOption,
-                                  CFROption, PLSIOption, W2VOption, WARPOption)
+from buffalo.algo.eals import EALS
+from buffalo.algo.options import (AlgoOption, ALSOption, BPRMFOption,
+                                  CFROption, EALSOption, PLSIOption, W2VOption,
+                                  WARPOption)
 from buffalo.algo.plsi import PLSI
 from buffalo.algo.w2v import W2V
 from buffalo.algo.warp import WARP
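
The reordering above is plain alphabetical sorting: the eals module import moves below cfr, and EALSOption moves to its alphabetical slot inside the parenthesized options import. A minimal sketch of checking a block the same way, assuming the "import rule" linter is isort (the commit message does not name the tool):

# Minimal sketch, assuming the import-order linter is isort
# (pip install isort); the commit message only says "import rule".
import isort

messy = (
    "from buffalo.algo.eals import EALS\n"
    "from buffalo.algo.base import Algo\n"
)
print(isort.check_code(messy, show_diff=False))  # False: out of order
print(isort.code(messy))
# from buffalo.algo.base import Algo
# from buffalo.algo.eals import EALS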
buffalo/algo/eals.py (24 changes: 12 additions & 12 deletions)
@@ -75,7 +75,7 @@ def init_factors(self):
         for name, rows in [("P", header["num_users"]), ("Q", header["num_items"])]:
             setattr(self, name, None)
             setattr(self, name, np.random.normal(scale=1.0 / (self.opt.d ** 2),
-                    size=(rows, self.vdim)).astype("float32"))
+                                                 size=(rows, self.vdim)).astype("float32"))
         self.P[:, self.opt.d:] = 0.0
         self.Q[:, self.opt.d:] = 0.0
         self.C = self._get_negative_weights()
@@ -151,17 +151,17 @@ def train(self, training_callback: Optional[Callable[[int, Dict[str, float]], No
             train_t = time.time() - start_t
             metrics = {"train_loss": loss}
             if self.opt.validation and \
-                self.opt.evaluation_on_learning and \
-                self.periodical(self.opt.evaluation_period, i):
-                start_t = time.time()
-                self.validation_result = self.get_validation_results()
-                vali_t = time.time() - start_t
-                val_str = " ".join([f"{k}:{v:0.5f}" for k, v in self.validation_result.items()])
-                self.logger.info(f"Validation: {val_str} Elapsed {vali_t:0.3f} secs")
-                metrics.update({"val_%s" % k: v
-                                for k, v in self.validation_result.items()})
-                if training_callback is not None and callable(training_callback):
-                    training_callback(i, metrics)
+                    self.opt.evaluation_on_learning and \
+                    self.periodical(self.opt.evaluation_period, i):
+                start_t = time.time()
+                self.validation_result = self.get_validation_results()
+                vali_t = time.time() - start_t
+                val_str = " ".join([f"{k}:{v:0.5f}" for k, v in self.validation_result.items()])
+                self.logger.info(f"Validation: {val_str} Elapsed {vali_t:0.3f} secs")
+                metrics.update({"val_%s" % k: v
+                                for k, v in self.validation_result.items()})
+                if training_callback is not None and callable(training_callback):
+                    training_callback(i, metrics)
             self.logger.info("Iteration %d: RMSE %.3f TotalLoss %.3f Elapsed %.3f secs" % (i + 1, loss, (total_loss / self._nnz), train_t))
             best_loss = self.save_best_only(loss, best_loss, i)
             if self.early_stopping(loss):
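
Both hunks above change indentation only. In init_factors() the size= argument is re-aligned under the opening parenthesis, and in train() the continued if condition is pushed one level past the body it guards. This matches pycodestyle's continuation-line checks as enforced by Flake8; the exact codes involved (E128 for the first pattern) are my reading of the change, not quoted from the lint output. A minimal sketch of both patterns:

# Sketch of the continuation-line style the hunks above move toward;
# the E128 attribution is an assumption (the commit only says "Flake8").

# Under-indented continuation inside a call: pycodestyle flags the
# first form (E128) because the second line ignores the visual indent
# set by the opening parenthesis.
total = max(1.0,
    2.0)          # flagged
total = max(1.0,
            2.0)  # clean: aligned with the first argument

# Backslash continuation: indenting the continued condition past the
# body keeps the condition visually distinct from the block it guards.
ready, valid = True, True
if ready and \
        valid:
    pass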
include/buffalo/misc/blas.hpp (3 changes: 3 additions & 0 deletions)
@@ -1,5 +1,8 @@
 #pragma once

+#include <string>
+#include <algorithm>
+

 extern "C" {
 // blas subroutines
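
The two includes above are the usual answer to cpplint's build/include_what_you_use check, which asks a file to directly include the headers for the symbols it uses (e.g. "Add #include <string> for string"). A sketch of re-running that check, assuming cpplint is installed from PyPI and the working directory is the repository root:

# Sketch: re-run cpplint on the header and keep only the
# include-what-you-use diagnostics (assumes `pip install cpplint`
# and that we run from the repository root).
import subprocess

proc = subprocess.run(
    ["cpplint", "include/buffalo/misc/blas.hpp"],
    capture_output=True, text=True,
)
# cpplint writes diagnostics to stderr; these lines disappear once
# <string> and <algorithm> are included directly.
for line in proc.stderr.splitlines():
    if "build/include_what_you_use" in line:
        print(line)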
tests/algo/test_eals.py (7 changes: 2 additions & 5 deletions)
@@ -1,9 +1,6 @@
 import unittest

-from loguru import logger
-
-import buffalo
-from buffalo import EALS, EALSOption, aux, set_log_level, MatrixMarketOptions, set_log_level
+from buffalo import EALS, EALSOption, aux, set_log_level

 from .base import TestBase

@@ -61,7 +58,7 @@ def test07_train_ml_20m(self):
         opt.num_workers = 8
         opt.validation = aux.Option({"topk": 10})
         self._test7_train_ml_20m(EALS, opt)
-
+
     def test08_serialization(self):
         opt = EALSOption().get_default_option()
         opt.d = 5
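
The deleted lines above imported buffalo and logger without using them and listed set_log_level twice in one import. Pyflakes (run via Flake8) reports unused imports as F401 and a repeated name as a redefinition; that these exact codes fired here is my inference, not quoted from the lint output. An annotated sketch:

# Annotated sketch of the cleanup; the F401 codes are the standard
# pyflakes ones, inferred rather than copied from the lint run.
from loguru import logger  # F401: imported but unused by the tests

import buffalo             # F401: imported but unused
from buffalo import (EALS, EALSOption, aux, set_log_level,
                     MatrixMarketOptions,   # F401: unused here
                     set_log_level)         # duplicate of the name above

# After the cleanup, only the names the tests actually use remain:
from buffalo import EALS, EALSOption, aux, set_log_level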
