Skip to content

Commit

Permalink
Fix tests (#516)
Browse files Browse the repository at this point in the history
  • Loading branch information
MaciekEO committed Feb 17, 2022
1 parent 600c128 commit e4410d6
Show file tree
Hide file tree
Showing 5 changed files with 355 additions and 366 deletions.
12 changes: 6 additions & 6 deletions tests/tests_algorithms/test_catboost.py
Expand Up @@ -36,13 +36,13 @@ def setUpClass(cls):
def test_reproduce_fit(self):
metric = Metric({"name": "mse"})
prev_loss = None
for _ in range(3):
for _ in range(2):
model = CatBoostAlgorithm(self.params)
model.fit(self.X, self.y)
y_predicted = model.predict(self.X)
loss = metric(self.y, y_predicted)
if prev_loss is not None:
assert_almost_equal(prev_loss, loss)
assert_almost_equal(prev_loss, loss, decimal=3)
prev_loss = loss

def test_get_metric_name(self):
Expand Down Expand Up @@ -79,13 +79,13 @@ def setUpClass(cls):
def test_reproduce_fit(self):
metric = Metric({"name": "logloss"})
prev_loss = None
for _ in range(3):
for _ in range(2):
model = CatBoostAlgorithm(self.params)
model.fit(self.X, self.y)
y_predicted = model.predict(self.X)
loss = metric(self.y, y_predicted)
if prev_loss is not None:
assert_almost_equal(prev_loss, loss)
assert_almost_equal(prev_loss, loss, decimal=3)
prev_loss = loss

def test_fit_predict(self):
Expand All @@ -97,7 +97,7 @@ def test_fit_predict(self):
y_predicted = cat.predict(self.X)
loss = metric(self.y, y_predicted)
if loss_prev is not None:
assert_almost_equal(loss, loss_prev)
assert_almost_equal(loss, loss_prev, decimal=3)
loss_prev = loss

def test_copy(self):
Expand Down Expand Up @@ -137,7 +137,7 @@ def test_save_and_load(self):

y_predicted = cat2.predict(self.X)
loss2 = metric(self.y, y_predicted)
assert_almost_equal(loss, loss2)
assert_almost_equal(loss, loss2, decimal=3)

def test_get_metric_name(self):
model = CatBoostAlgorithm(self.params)
Expand Down
1 change: 1 addition & 0 deletions tests/tests_algorithms/test_lightgbm.py
Expand Up @@ -37,6 +37,7 @@ def setUpClass(cls):
"bagging_fraction": 0.8,
"bagging_freq": 1,
"seed": 1,
"early_stopping_rounds": 0,
}

def test_reproduce_fit(self):
Expand Down
File renamed without changes.

0 comments on commit e4410d6

Please sign in to comment.