Skip to content

Commit

Permalink
Add test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 11, 2020
1 parent b383798 commit 8d3bae4
Showing 1 changed file with 22 additions and 15 deletions.
37 changes: 22 additions & 15 deletions tests/python-gpu/test_gpu_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,15 @@
class TestGPUBasicModels(unittest.TestCase):
cputest = test_bm.TestModels()

def test_eta_decay_gpu_hist(self):
self.cputest.run_eta_decay('gpu_hist')

def test_deterministic_gpu_hist(self):
kRows = 1000
kCols = 64
kClasses = 4
# Create large values to force rounding.
X = np.random.randn(kRows, kCols) * 1e4
y = np.random.randint(0, kClasses, size=kRows)

def run_cls(self, X, y, deterministic):
cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=True,
deterministic_histogram=deterministic,
single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-0.json')

cls = xgb.XGBClassifier(tree_method='gpu_hist',
deterministic_histogram=True,
deterministic_histogram=deterministic,
single_precision_histogram=True)
cls.fit(X, y)
cls.get_booster().save_model('test_deterministic_gpu_hist-1.json')
Expand All @@ -40,7 +30,24 @@ def test_deterministic_gpu_hist(self):
with open('test_deterministic_gpu_hist-1.json', 'r') as fd:
model_1 = fd.read()

assert hash(model_0) == hash(model_1)

os.remove('test_deterministic_gpu_hist-0.json')
os.remove('test_deterministic_gpu_hist-1.json')

return hash(model_0), hash(model_1)

def test_eta_decay_gpu_hist(self):
self.cputest.run_eta_decay('gpu_hist')

def test_deterministic_gpu_hist(self):
kRows = 1000
kCols = 64
kClasses = 4
# Create large values to force rounding.
X = np.random.randn(kRows, kCols) * 1e4
y = np.random.randint(0, kClasses, size=kRows) * 1e4

model_0, model_1 = self.run_cls(X, y, True)
assert hash(model_0) == hash(model_1)

model_0, model_1 = self.run_cls(X, y, False)
assert hash(model_0) != hash(model_1)

0 comments on commit 8d3bae4

Please sign in to comment.