In [72]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.model_selection import train_test_split
from typing import List

In [73]:
class KnnRegressor:
    def __init__(self, k:int=5, n_features:int=100, noise:int=50) -> None:
        self.k = k
        self.n_features = n_features
        self.noise = noise
    
    def load_data(self) -> None:
        X, y = make_regression(n_features=self.n_features, noise=self.noise)
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(X, y)
        
    def predict(self, x:List) -> float:
        train_array = np.array(list(zip(self.X_train, self.y_train)), dtype=object)
        predict_values = np.array([])
        
        if x.shape[1] != self.X_train.shape[1]:
            raise ValueError("value x is not suitable")
            
        for item in x:
            norm_list = [(np.linalg.norm(item - train_array[i][0]), train_array[i][1]) for i in range(train_array.shape[0])]
            norm_list = sorted(norm_list, key=lambda x: x[0])
            
            predict_values = np.append(predict_values, np.mean(norm_list[:self.k]))
        
        return predict_values

In [74]:
regressor = KnnRegressor(k=5)
regressor.load_data()
regressor.predict(regressor.X_test)

array([ 51.90999666,  14.52065883,  21.38036438,  -1.09070422,
       -12.50116897,  16.42582899,  38.42106931, -56.49306168,
        39.13307637,   2.23930946,  -7.62133634,  55.75532152,
       -48.28959562,  18.61518136,  67.33488561, -60.23594632,
        23.79778305,   1.6093096 ,   8.93323579,  24.37116165,
         7.90152424,  15.17386128,   6.13314841,  -1.22074885,
        34.90685269])

In [75]:
neighbors = KNeighborsRegressor()
neighbors.fit(X=regressor.X_train, y=regressor.y_train)
neighbors.predict(regressor.X_test)

array([  91.26293789,   15.39028948,   29.76732615,  -15.08752955,
        -38.10502699,   20.84343405,   64.97911786, -126.19087911,
         65.99209955,   -8.3103608 ,  -28.47312865,   98.79294874,
       -108.9287639 ,   25.16440187,  121.54336397, -132.24000186,
         34.96368022,   -8.71465145,    6.39051636,   35.60117955,
          4.30162817,   18.60689429,   -1.36214743,  -14.69075085,
         56.39164492])