Skip to content

Commit

Permalink
Merge pull request #679 from Hk669/hk669
Browse files Browse the repository at this point in the history
added _classes for knn classifiers
  • Loading branch information
pplonski committed Dec 14, 2023
2 parents a99a33a + 4c29991 commit 8678d84
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
8 changes: 8 additions & 0 deletions supervised/algorithms/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ def fit(
else:
self.model.fit(X, y)

@property
def _classes(self):
# Returns the unique classes based on the fitted model
if hasattr(self.model, "classes_"):
return self.model.classes_
else:
return None


class KNeighborsAlgorithm(KNNFit, RegressorMixin):
algorithm_name = "k-Nearest Neighbors"
Expand Down
18 changes: 17 additions & 1 deletion tests/tests_algorithms/test_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@ class KNeighborsRegressorAlgorithmTest(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.X, cls.y = datasets.make_regression(
n_samples=100, n_features=5, n_informative=4, shuffle=False, random_state=0
n_samples=100,
n_features=5,
n_informative=4,
shuffle=False,
random_state=0
)

def test_reproduce_fit(self):
Expand Down Expand Up @@ -77,3 +81,15 @@ def test_is_fitted(self):
self.assertFalse(model.is_fitted())
model.fit(self.X, self.y)
self.assertTrue(model.is_fitted())

def test_classes_attribute(self):
params = {"ml_task": "binary_classification"}
model = KNeighborsAlgorithm(params)
model.fit(self.X,self.y)

try:
classes = model._classes
except AttributeError:
classes = None

self.assertTrue(np.array_equal(np.unique(self.y), classes))

0 comments on commit 8678d84

Please sign in to comment.