In [3]:
from typing import List
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split

In [4]:
# Fixed seed so the generated data, the train/test split, and every output
# below are reproducible under Restart Kernel -> Run All.
RANDOM_STATE = 42
X, y = make_regression(n_features=100, noise=100, random_state=RANDOM_STATE)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=RANDOM_STATE)

In [5]:
class KnnRegressor:
    """Brute-force k-nearest-neighbours regressor.

    Each prediction is the unweighted mean target of the ``k_neighbors``
    training samples closest to the query point in Euclidean distance.
    """

    def __init__(self, k_neighbors: int = 5) -> None:
        """Create a regressor.

        Args:
            k_neighbors: number of nearest training samples averaged per
                prediction. Must be at least 1.

        Raises:
            ValueError: if ``k_neighbors`` is smaller than 1.
        """
        if k_neighbors < 1:
            # k=0 would average an empty neighbour set and silently yield NaN.
            raise ValueError("k_neighbors must be at least 1")
        self.k_neighbors = k_neighbors

    def load_data(self, X: List[List[float]], y: List[float]) -> None:
        """Store the training set; KNN is lazy, so no fitting happens here.

        Args:
            X: training features, shape (n_samples, n_features).
            y: training targets, one entry per row of ``X``.

        Raises:
            ValueError: if ``X`` is not 2-D, is empty, or ``y`` does not have
                one entry per row of ``X``.
        """
        # Convert up front so plain Python lists (as the type hints allow)
        # work with the array operations used in ``predict``.
        X_arr = np.asarray(X)
        y_arr = np.asarray(y)
        if X_arr.ndim != 2 or X_arr.shape[0] == 0:
            raise ValueError("X must be a non-empty 2-D array of shape (n_samples, n_features)")
        if y_arr.shape[:1] != X_arr.shape[:1]:
            raise ValueError("X and y must contain the same number of samples")
        self.X, self.y = X_arr, y_arr

    def predict(self, x: List[List[float]]) -> np.ndarray:
        """Predict a target for each row of ``x``.

        Args:
            x: query points, shape (n_queries, n_features); ``n_features``
                must match the data passed to ``load_data``.

        Returns:
            1-D float array with one prediction per query row.

        Raises:
            ValueError: if ``x`` is not 2-D or its feature count differs from
                the training data.
        """
        x = np.asarray(x)
        if x.ndim != 2 or x.shape[1] != self.X.shape[1]:
            raise ValueError("value x is not suitable")

        # Preallocate rather than np.append in the loop, which copies the
        # whole array on every iteration (quadratic in the number of queries).
        predict_values = np.empty(len(x), dtype=float)
        for row_index, item in enumerate(x):
            norm_list = np.linalg.norm(self.X - item, axis=1)
            neighbors_indexes = norm_list.argsort()[:self.k_neighbors]
            predict_values[row_index] = np.mean(self.y[neighbors_indexes])

        return predict_values

In [6]:
regressor = KnnRegressor(k_neighbors=5)
# KNN is lazy: "loading" only stores the training split for later distance lookups.
regressor.load_data(y=y_train, X=X_train)
regressor.predict(X_test)

array([ 54.47953868,  21.08016235, -43.6175809 ,   8.90629141,
       162.43095931,   7.04222428, -71.95633579,   1.19332378,
       -97.79815511, -14.71249799, -64.77158741, -72.38039489,
       -20.54270229, -42.46035945,   7.13635113,  21.82649366,
        75.44731069, -37.03430724, -53.88567724,  18.01207748,
       -99.9969071 , -55.07143032, -12.67845186,  45.77681101,
       -55.57908308])

In [7]:
neighbors = KNeighborsRegressor(n_neighbors=5)
neighbors.fit(X=X_train, y=y_train)
sklearn_predictions = neighbors.predict(X_test)
# Check programmatically that the hand-written KnnRegressor (cell above)
# matches scikit-learn, instead of comparing the two printed arrays by eye.
np.testing.assert_allclose(regressor.predict(X_test), sklearn_predictions)
sklearn_predictions

array([ 54.47953868,  21.08016235, -43.6175809 ,   8.90629141,
       162.43095931,   7.04222428, -71.95633579,   1.19332378,
       -97.79815511, -14.71249799, -64.77158741, -72.38039489,
       -20.54270229, -42.46035945,   7.13635113,  21.82649366,
        75.44731069, -37.03430724, -53.88567724,  18.01207748,
       -99.9969071 , -55.07143032, -12.67845186,  45.77681101,
       -55.57908308])