In [94]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from typing import List

In [104]:
X, y = make_regression(n_features=100, noise=100)
X_train, X_test, y_train, y_test = train_test_split(X, y)

In [124]:
class KnnRegressor:
    def __init__(self, k_neighbors:int=5) -> None:
        self.k_neighbors = k_neighbors
    
    def load_data(self, n_features:int=100, noise:int=50) -> None:
        self.n_features = n_features
        self.noise = noise
        
        self.X_train, self.y_train = X_train, y_train
        
    def predict(self, x:List[float]) -> List[float]:
        predict_values = np.array([])
        
        if x.shape[1] != self.X_train.shape[1]:
            raise ValueError("value x is not suitable")
            
        for item in x:
            norm_list = [(np.linalg.norm(item - self.X_train[i]), self.y_train[i]) for i in range(self.X_train.shape[0])]
            norm_list = sorted(norm_list, key=lambda x: x[0])
            
            predict_values = np.append(predict_values, np.mean(norm_list[:self.k_neighbors][1]))
        
        return predict_values

In [125]:
regressor = KnnRegressor(k_neighbors=5)
regressor.load_data()
regressor.predict(X_test)

array([ 130.99559906,  130.55163902, -157.33191198,   47.36803537,
       -112.87498987,  103.78405561,  -38.5058986 ,  -37.47208115,
       -133.80563017,  100.97322314,  162.15965592,  158.95575624,
         40.42840358, -252.68662432, -139.6796999 ,  -92.55923988,
        137.76069691,  -37.35029533,   92.83965975, -139.97561121,
        101.35911021,  -39.30206183, -140.00340949,   95.25429574,
        -42.40804932])

In [126]:
neighbors = KNeighborsRegressor(n_neighbors=5)
neighbors.fit(X=X_train, y=y_train)
neighbors.predict(X_test)

array([   8.65564616,  107.82328707, -127.483559  , -143.18368862,
        104.94730304,   22.21810192,  168.64780652,  -37.47894625,
        -93.32596568, -109.57477328,  233.87965358,   14.88385908,
          3.88285176, -208.60835286,  -67.22002238, -124.89822078,
        100.43247472,  -67.85029348,   31.51237607, -132.50863697,
          6.20285366,   62.98942148,  122.14995167,   11.75726584,
         21.74768354])