In [1]:
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

In [2]:
boston = datasets.load_boston()  # 波士顿房产数据集

X = boston.data
y = boston.target

X = X[y < 50.0]
y = y[y < 50.0]  # 剔除异常值(上限点)

In [3]:
X.shape

(490, 13)

In [13]:
y.shape

(490,)

In [4]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.333)

In [5]:
lin_reg = LinearRegression(n_jobs=-1)  # 多元线性回归
lin_reg.fit(X_train, y_train)

LinearRegression(copy_X=True, fit_intercept=True, n_jobs=None, normalize=False)

In [6]:
lin_reg.coef_  # 解释变量系数(多元线性方程形如:y=ax+b)

array([-9.92856151e-02,  4.61416564e-02, -1.71870780e-02, -2.69334762e-01,
       -9.71471629e+00,  3.71257623e+00, -3.40431851e-02, -1.25568474e+00,
        2.43212574e-01, -1.51248394e-02, -6.43148736e-01,  9.68896579e-03,
       -3.34079079e-01])

In [7]:
correlation = np.argsort(abs(lin_reg.coef_))  # 解释变量相关性排序
correlation

array([11,  9,  2,  6,  1,  0,  8,  3, 12, 10,  7,  5,  4], dtype=int64)

In [8]:
boston.feature_names  # 数据集所有的特征

array(['CRIM', 'ZN', 'INDUS', 'CHAS', 'NOX', 'RM', 'AGE', 'DIS', 'RAD',
       'TAX', 'PTRATIO', 'B', 'LSTAT'], dtype='<U7')

In [9]:
boston.feature_names[correlation]  # 按照解释变量相关性对数据集的特征进行排序(线性回归具有可解释性)

array(['B', 'TAX', 'INDUS', 'AGE', 'ZN', 'CRIM', 'RAD', 'CHAS', 'LSTAT',
       'PTRATIO', 'DIS', 'RM', 'NOX'], dtype='<U7')

In [10]:
lin_reg.intercept_  # 截距项

27.67807615829516

In [11]:
lin_reg.score(X_test, y_test)  # 模型的可决系数R平方

0.7995219970074476

In [12]:
lin_reg.predict(X_test)  # Predict using the linear model

array([22.07854147, 22.28066198, 28.62204763, 25.7706585 , 31.78689176,
       17.39965117, 14.28214238, 17.02168021, 23.87505835, 23.07038237,
       17.0320314 , 17.0839638 , 24.08801533, 19.06883787, 11.85669081,
       18.15847895, 23.10650584, 17.94816455, 20.3193119 , 16.62895394,
       21.62884567, 21.77301947, 12.63764547, 20.49389559, 21.00993234,
       21.05690724, 12.87995269, 23.95353488, 34.7121938 , 25.97417333,
       14.46214418, 32.96263134, 12.45344049, 21.44099076, 17.1464691 ,
       35.12582581, 29.96005797, 29.19330619, 20.84878274, 30.06149777,
       31.61833596, 13.96254948, 24.95473633, 19.84018356, 21.83789403,
       31.93001931, 19.33371013, 35.49134706, 10.62770666, 23.50697769,
       31.30349768, 17.75881753, 11.43715452, 21.83810829, 16.24200964,
       27.49238565, 19.50155802,  8.19522699, 32.88754628, 32.93351131,
       18.91701325, 25.42163883,  8.52737421, 27.75141997, 17.45251127,
       16.69542569, 26.54912697, 23.33743344,  7.94916852, 14.87