In [20]:
from typing import List
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split

In [37]:
# Fixed seed so Restart & Run All reproduces the same synthetic data, the same
# train/test split, and therefore the same predictions in the cells below.
SEED = 42
X, y = make_regression(n_features=100, noise=100, random_state=SEED)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=SEED)

In [38]:
class KnnRegressor:
    """Brute-force k-nearest-neighbours regressor (Euclidean distance, uniform weights).

    Each query point is predicted as the unweighted mean of the targets of its
    ``k_neighbors`` closest training samples.
    """

    def __init__(self, k_neighbors: int = 5) -> None:
        """
        Args:
            k_neighbors: number of nearest training samples averaged per prediction.

        Raises:
            ValueError: if ``k_neighbors`` is less than 1 (averaging zero
                neighbours would silently produce NaN predictions).
        """
        if k_neighbors < 1:
            raise ValueError(f"k_neighbors must be >= 1, got {k_neighbors}")
        self.k_neighbors = k_neighbors

    def load_data(self, X: List[List[float]], y: List[float]) -> None:
        """Store the training set; no model is fitted (KNN is a lazy learner).

        Args:
            X: training features, array-like of shape (n_samples, n_features).
            y: training targets, array-like with one value per row of ``X``.

        Raises:
            ValueError: if ``X`` is not 2-D or ``X`` and ``y`` differ in length.
        """
        # Convert up front so plain Python lists work too: predict() relies on
        # array broadcasting and on fancy indexing `self.y[indices]`.
        features = np.asarray(X)
        targets = np.asarray(y)
        if features.ndim != 2:
            raise ValueError(f"X must be 2-D (n_samples, n_features), got shape {features.shape}")
        if len(features) != len(targets):
            raise ValueError(f"X and y length mismatch: {len(features)} != {len(targets)}")
        self.X, self.y = features, targets

    def predict(self, x: List[List[float]]) -> np.ndarray:
        """Predict a target for every row of ``x``.

        Args:
            x: query points, array-like of shape (n_queries, n_features).

        Returns:
            1-D float array of length n_queries; element i is the mean target of
            the ``k_neighbors`` training samples nearest (Euclidean) to row i.
            If fewer than ``k_neighbors`` samples were loaded, all are averaged.

        Raises:
            ValueError: if ``x`` is not 2-D or its feature count differs from
                the training data.
        """
        queries = np.asarray(x)
        if queries.ndim != 2 or queries.shape[1] != self.X.shape[1]:
            raise ValueError(
                f"value x is not suitable: expected shape (n, {self.X.shape[1]}), "
                f"got {queries.shape}"
            )

        # Preallocate: np.append inside the loop copies the whole array every
        # iteration (quadratic). Looping one query at a time keeps the distance
        # computation at n_train x n_features memory per step.
        predictions = np.empty(len(queries), dtype=float)
        for i, query in enumerate(queries):
            distances = np.linalg.norm(self.X - query, axis=1)
            nearest = distances.argsort()[: self.k_neighbors]
            predictions[i] = np.mean(self.y[nearest])
        return predictions

In [39]:
# Fit the from-scratch regressor on the training split and display its
# predictions for the test split.
regressor = KnnRegressor(k_neighbors=5)
regressor.load_data(y=y_train, X=X_train)
custom_predictions = regressor.predict(X_test)
custom_predictions

array([-117.26609005,  103.12097673,  -96.79222601,  -85.02276022,
       -109.20802319,  -62.49546636,  -87.66991234,   -0.46029526,
         -1.47478273,   -4.33545319, -152.43243395,    1.36357421,
         -8.56286706,   -3.82373672,  -35.05493438, -197.30498676,
        -46.79734356, -184.7544811 ,  -61.12947952, -119.83326593,
         59.53742863,   -3.07001571, -173.7086403 , -172.41546476,
        -43.66090842])

In [40]:
# scikit-learn reference with the same k; its predictions should match the
# hand-written KnnRegressor (fitted as `regressor`) up to float rounding.
neighbors = KNeighborsRegressor(n_neighbors=5)
neighbors.fit(X=X_train, y=y_train)
sklearn_predictions = neighbors.predict(X_test)

# Verify agreement programmatically instead of eyeballing two printed arrays.
np.testing.assert_allclose(sklearn_predictions, regressor.predict(X_test))
sklearn_predictions

array([-117.26609005,  103.12097673,  -96.79222601,  -85.02276022,
       -109.20802319,  -62.49546636,  -87.66991234,   -0.46029526,
         -1.47478273,   -4.33545319, -152.43243395,    1.36357421,
         -8.56286706,   -3.82373672,  -35.05493438, -197.30498676,
        -46.79734356, -184.7544811 ,  -61.12947952, -119.83326593,
         59.53742863,   -3.07001571, -173.7086403 , -172.41546476,
        -43.66090842])