In [1]:
%matplotlib inline
import numpy as np
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import IPython as ip
import pandas as pd
import seaborn as sns
import statsmodels.api as sm
import statsmodels.formula.api as smf

In [2]:
# Global plotting defaults for the notebook: start from the ggplot style sheet,
# then override individual rcParams on top of it.
mpl.style.use('ggplot')
mpl.rcParams.update({
    # Default figure size, in inches.
    'figure.figsize': (7.2, 5.76),
    # Font family used for all text in figures.
    'font.family': 'Noto Sans CJK TC',
    # Thin white outline around line markers.
    'lines.markeredgecolor': 'white',
    'lines.markeredgewidth': 0.75,
    # White edges on patches, forced on even where they would be omitted.
    'patch.edgecolor': 'white',
    'patch.force_edgecolor': True,
    'patch.linewidth': 1,
})
# Render inline figures as SVG.
# NOTE(review): newer IPython releases deprecate this helper in favour of
# matplotlib_inline.backend_inline.set_matplotlib_formats -- confirm against
# the IPython version this notebook targets before switching.
ip.display.set_matplotlib_formats('svg')

In [3]:
from sklearn import preprocessing
from sklearn import model_selection
from sklearn import neighbors
from sklearn import metrics

In [4]:
# Load the `anes96` dataset module shipped with statsmodels and pull out its
# pandas DataFrame. `df_raw` is the reference kept for later cells to restart
# from; `df` is bound to the very same object (no copy is made).
m = sm.datasets.anes96
df_raw = m.load_pandas().data
df = df_raw

In [5]:
# Rebind `df` to the loaded frame so this cell does not depend on whatever a
# previously executed cell left in `df`.
df = df_raw

# Target: the `vote` column. Features: every other column, in original order.
s_y = df['vote']
df_X = df.loc[:, df.columns != 'vote']

# Plain NumPy arrays for the scikit-learn calls further down.
X_raw = df_X.values
y_raw = s_y.values

In [6]:
X = X_raw
y = y_raw

X = preprocessing.scale(X)

X_train, X_test, y_train, y_test = model_selection.train_test_split(
    X, y, test_size=0.4, random_state=20200502
)

print('The model:\n')
nbrs = neighbors.NearestNeighbors()
display(nbrs)
print()

print('The training time:\n')
%time nbrs.fit(X_train)
print()

n = 3
indices = nbrs.kneighbors(X_test, n, return_distance=False)

pseudo_y_true = (y_test.repeat(n)
                       .reshape((y_test.shape[0], n))
                       .flatten())
pseudo_y_pred = y_train[indices.flatten()]

print('The pseudo metrics:\n')
print(metrics.classification_report(pseudo_y_true, pseudo_y_pred))

The model:



NearestNeighbors(algorithm='auto', leaf_size=30, metric='minkowski',
                 metric_params=None, n_jobs=None, n_neighbors=5, p=2,
                 radius=1.0)


The training time:

CPU times: user 418 µs, sys: 52 µs, total: 470 µs
Wall time: 453 µs

The pseudo metrics:

              precision    recall  f1-score   support

         0.0       0.90      0.85      0.88       663
         1.0       0.81      0.87      0.84       471

    accuracy                           0.86      1134
   macro avg       0.85      0.86      0.86      1134
weighted avg       0.86      0.86      0.86      1134

