Skip to content

Commit

Permalink
feat(classifiers): implement KNN algorithm
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieu committed Aug 8, 2019
1 parent b2bc8f1 commit 27e1875
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Alchina is a Machine Learning framework.

- Linear classifier
- Ridge classifier
- K-Nearest Neighbors

**Clusters**

Expand Down
53 changes: 53 additions & 0 deletions alchina/classifiers/knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""K-Nearest Neighbors"""

import numpy as np

from collections import Counter

from alchina.exceptions import NotFitted
from alchina.metrics import accuracy_score


class KNNClassifier(object):
"""K-Nearest Neighbors algorithm"""

def __init__(self, n_neighbors=3):
self.n_neighbors = n_neighbors

self.X_fit = None
self.y_fit = None

def euclidian(self, a, b):
"""Compute the euclidian distance between two samples."""
return np.linalg.norm(a - b)

def fit(self, X, y):
"""Train the model."""
self.X_fit = X
self.y_fit = y

def predict(self, X):
"""Predict a target given features."""
if self.X_fit is None or self.y_fit is None:
raise NotFitted("the model must be fitted before usage")

labels = []
for x in X:
distances_labels = [
(self.euclidian(x, x_fit), y_fit)
for x_fit, y_fit in zip(self.X_fit, self.y_fit)
]
neighbors = sorted(distances_labels, key=lambda d: d[0])[: self.n_neighbors]
neighbors_labels = [neighbor[1][0] for neighbor in neighbors]
labels.append(
sorted(
neighbors_labels, key=Counter(neighbors_labels).get, reverse=True
)[0]
)
return np.array(labels).reshape(-1, 1)

def score(self, X, y):
"""Score of the model."""
if self.X_fit is None or self.y_fit is None:
raise NotFitted("the model must be fitted before usage")
return accuracy_score(self.predict(X), y)
67 changes: 67 additions & 0 deletions tests/classifiers/test_knn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
"""K-Nearest Neighbors tests."""

import numpy as np
import pytest

from alchina.classifiers import KNNClassifier
from alchina.exceptions import NotFitted


# --- Linear classifier ---


def test_knn_classifier():
"""Test of `KNNClassifier` class."""
knn = KNNClassifier(1)

X = np.array([[0], [1]])
y = np.array([[0], [1]])

knn.fit(X, y)

assert knn.score(X, y) == 1


def test_knn_classifier_predict():
"""Test of `KNNClassifier` class with a prediction."""
knn = KNNClassifier(1)

X = np.array([[0], [1]])
y = np.array([[0], [1]])

knn.fit(X, y)

assert np.equal(knn.predict(np.array([0])), np.array([0]))


def test_knn_classifier_multiclass():
"""Test of `LinearClassifier` with multiclass."""
knn = KNNClassifier(1)

X = np.array([[0], [1], [2]])
y = np.array([[0], [1], [2]])

knn.fit(X, y)

assert knn.score(X, y) == 1


def test_knn_classifier_predict_not_fitted():
"""Test of `KNNClassifier` class with prediction without fit."""
knn = KNNClassifier(1)

X = np.array([[0], [1]])

with pytest.raises(NotFitted):
knn.predict(X)


def test_knn_classifier_score_not_fitted():
"""Test of `KNNClassifier` class with score calculation without fit."""
knn = KNNClassifier(1)

X = np.array([[0], [1]])
y = np.array([[0], [1]])

with pytest.raises(NotFitted):
knn.score(X, y) == 1

0 comments on commit 27e1875

Please sign in to comment.