Predict Cereal Rating with Linear Regression

In [38]:
import pandas as pd
import statsmodels.api as sm
import numpy as np
from sklearn.model_selection import train_test_split

In [2]:
data = pd.read_csv("../data/cereal.csv")

In [3]:
data

Unnamed: 0,name,mfr,type,calories,protein,fat,sodium,fiber,carbs,sugars,potass,vitamins,shelf,weight,cups,rating
0,100% Bran,N,C,70,4,1,130,10.0,5.0,6,280,25,3,1.0,0.33,68.402973
1,100% Natural Bran,Q,C,120,3,5,15,2.0,8.0,8,135,0,3,1.0,1.00,33.983679
2,All-Bran,K,C,70,4,1,260,9.0,7.0,5,320,25,3,1.0,0.33,59.425505
3,All-Bran with Extra Fiber,K,C,50,4,0,140,14.0,8.0,0,330,25,3,1.0,0.50,93.704912
4,Almond Delight,R,C,110,2,2,200,1.0,14.0,8,-1,25,3,1.0,0.75,34.384843
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72,Triples,G,C,110,2,1,250,0.0,21.0,3,60,25,3,1.0,0.75,39.106174
73,Trix,G,C,110,1,1,140,0.0,13.0,12,25,25,2,1.0,1.00,27.753301
74,Wheat Chex,R,C,100,3,1,230,3.0,17.0,3,115,25,1,1.0,0.67,49.787445
75,Wheaties,G,C,100,3,1,200,3.0,17.0,3,110,25,1,1.0,1.00,51.592193


In [12]:
data_prep = pd.concat([data, pd.get_dummies(data['type']), pd.get_dummies(data['mfr'])], axis=1)

In [13]:
data_prep

Unnamed: 0,name,mfr,type,calories,protein,fat,sodium,fiber,carbs,sugars,...,rating,C,H,A,G,K,N,P,Q,R
0,100% Bran,N,C,70,4,1,130,10.0,5.0,6,...,68.402973,1,0,0,0,0,1,0,0,0
1,100% Natural Bran,Q,C,120,3,5,15,2.0,8.0,8,...,33.983679,1,0,0,0,0,0,0,1,0
2,All-Bran,K,C,70,4,1,260,9.0,7.0,5,...,59.425505,1,0,0,0,1,0,0,0,0
3,All-Bran with Extra Fiber,K,C,50,4,0,140,14.0,8.0,0,...,93.704912,1,0,0,0,1,0,0,0,0
4,Almond Delight,R,C,110,2,2,200,1.0,14.0,8,...,34.384843,1,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
72,Triples,G,C,110,2,1,250,0.0,21.0,3,...,39.106174,1,0,0,1,0,0,0,0,0
73,Trix,G,C,110,1,1,140,0.0,13.0,12,...,27.753301,1,0,0,1,0,0,0,0,0
74,Wheat Chex,R,C,100,3,1,230,3.0,17.0,3,...,49.787445,1,0,0,0,0,0,0,0,1
75,Wheaties,G,C,100,3,1,200,3.0,17.0,3,...,51.592193,1,0,0,1,0,0,0,0,0


In [14]:
X = data_prep.drop(['name', 'mfr', 'type', 'rating'], axis=1)
y = data_prep['rating']

In [15]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

In [23]:
model = sm.OLS(y_train.values, sm.add_constant(X_train.values))
results = model.fit()
print(results.summary())

                            OLS Regression Results                            
Dep. Variable:                      y   R-squared:                       1.000
Model:                            OLS   Adj. R-squared:                  1.000
Method:                 Least Squares   F-statistic:                 4.888e+15
Date:                Tue, 09 Feb 2021   Prob (F-statistic):          1.30e-234
Time:                        12:03:34   Log-Likelihood:                 702.69
No. Observations:                  51   AIC:                            -1365.
Df Residuals:                      31   BIC:                            -1327.
Df Model:                          19                                         
Covariance Type:            nonrobust                                         
                 coef    std err          t      P>|t|      [0.025      0.975]
------------------------------------------------------------------------------
const         33.4339   3.36e-07   9.95e+07      0.0

In [33]:
pred = results.predict(sm.add_constant(X_test.values, has_constant="add"))

In [43]:
np.sqrt(np.mean((pred - y_test.values)**2))

3.433656603230979e-07

In [44]:
y_test

4     34.384843
35    21.871292
10    18.042851
0     68.402973
45    34.139765
47    40.105965
66    31.230054
53    41.503540
50    59.642837
28    41.015492
68    59.363993
74    49.787445
18    22.396513
12    19.823573
58    39.259197
33    53.371007
9     53.313813
5     29.509541
34    45.811716
22    36.176196
30    35.252444
40    39.241114
39    36.471512
16    45.863324
65    72.801787
54    60.756112
Name: rating, dtype: float64

In [45]:
pred

array([34.3848433 , 21.87129191, 18.04285048, 68.40297285, 34.13976421,
       40.10596493, 31.23005441, 41.50354016, 59.64283677, 41.01549166,
       59.36399367, 49.78744513, 22.39651281, 19.82357265, 39.25919728,
       53.37100735, 53.31381308, 29.50954065, 45.81171614, 36.17619609,
       35.25244412, 39.24111435, 36.47151208, 45.86332461, 72.80178716,
       60.75611159])