Skip to content

Commit

Permalink
Merge pull request #6488 from markotoplak/fix-ubuntu-tree
Browse files Browse the repository at this point in the history
[FIX] Fix classification trees for data with repeated feature values
  • Loading branch information
janezd committed Jul 13, 2023
2 parents 929bd7a + cbe2e11 commit d17f021
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Orange/classification/_tree_scorers.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def find_threshold_entropy(const double[:] x, const double[:] y,
curr_y = <int>y[idx[i]]
distr[curr_y] -= 1
distr[n_classes + curr_y] += 1
- if curr_y != y[idx[i + 1]] and x[idx[i]] != x[idx[i + 1]]:
+ if x[idx[i]] != x[idx[i + 1]]:
entro = (i + 1) * log(i + 1) + (N - i - 1) * log(N - i - 1)
for j in range(2 * n_classes):
if distr[j]:
Expand Down
36 changes: 35 additions & 1 deletion Orange/tests/test_orangetree.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np
import scipy.sparse as sp
from Orange.classification._tree_scorers import find_threshold_entropy

from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
from Orange.classification.tree import \
Expand Down Expand Up @@ -34,7 +35,7 @@ def test_full_tree(self):
learn = self.TreeLearner(**self.no_pruning_args)
clf = learn(table)
pred = clf(table)
- self.assertTrue(np.all(table.Y.flatten() == pred))
+ np.testing.assert_equal(table.Y.flatten(), pred)

def test_min_samples_split(self):
clf = self.TreeLearner(
Expand Down Expand Up @@ -448,3 +449,36 @@ def test_compile_and_run_cont_sparse(self):
[14, 2, 1]], dtype=float
))
np.testing.assert_equal(model.get_values(x), expected_values)


class TestScorers(unittest.TestCase):
    """Direct checks of the compiled ``find_threshold_entropy`` scorer."""

    def test_find_threshold_entropy(self):
        """Four distinct feature values with two classes, two samples each.

        Expects a score of (almost) 1 and a threshold of 2.0.
        """
        values = np.array([1, 2, 3, 4], dtype=float)
        classes = np.array([0, 0, 1, 1], dtype=float)
        order = np.argsort(values, kind="stable")
        # NOTE(review): the trailing 2 and 1 are presumably the number of
        # classes and the minimal leaf size -- confirm against the scorer.
        score, threshold = find_threshold_entropy(values, classes, order, 2, 1)
        self.assertAlmostEqual(score, 1)
        self.assertEqual(threshold, 2.0)

    def test_find_threshold_entropy_repeated(self):
        """Feature with only two distinct values, each repeated three times.

        Whatever the class layout, the reported threshold must be 1.0; only
        the expected score differs between the cases.
        """
        # (class labels, expected score) -- checked in this order, stopping
        # at the first failing assertion.
        cases = (
            ([0, 0, 0, 0, 1, 1], 0.459147917027245),
            ([0, 0, 1, 1, 1, 1], 0.459147917027245),
            ([0, 1, 1, 1, 1, 1], 0.19087450462110966),
        )
        for labels, expected_score in cases:
            values = np.array([1, 1, 1, 2, 2, 2], dtype=float)
            classes = np.array(labels, dtype=float)
            order = np.argsort(values, kind="stable")
            score, threshold = find_threshold_entropy(
                values, classes, order, 2, 1)
            self.assertAlmostEqual(score, expected_score)
            self.assertEqual(threshold, 1.0)

0 comments on commit d17f021

Please sign in to comment.