# 多項式回帰分析
Rデータセットのcarsを用いる。
このデータセットを使用するには，パッケージrpy2を事前にインストールすること<br>
Documentation for rpy2 https://rpy2.readthedocs.io/en/version_2.8.x/<br>

carsの説明 : 次のサイトからcarsを検索  
https://stat.ethz.ch/R-manual/R-devel/library/datasets/html/00Index.html

In [1]:
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

import statsmodels.formula.api as smf

from rpy2.robjects import r, pandas2ri
# NOTE(review): presumably this switches on rpy2's R <-> pandas conversion so
# that r['cars'] yields a pandas DataFrame (the later cells call df.head() and
# df.plot()) — confirm against the installed rpy2 version.
pandas2ri.activate()

# When True, each plotting cell also writes its figure to an .eps file via
# plt.savefig() before plt.show(); when False the figures are only displayed.
FLAG_fig = False

ModuleNotFoundError: No module named 'rpy2'

In [None]:
# Fetch R's built-in `cars` data set through rpy2.
df = r['cars']

# Keep the speed column handy; later cells use it as the x-axis of the
# fitted curves.
x = df['speed']

# Preview the first rows (shown as the cell output).
df.head()

#### 1次モデル
$y = b_0 + b_1 x$

In [None]:
# First-order model: regress dist on speed with ordinary least squares.
result1 = smf.ols(
    formula='dist ~ speed',
    data=df,
).fit()
print(result1.summary())

# The plotting cell uses b0 as the constant term and b1 as the speed
# coefficient (y = b0 + b1 * x).
b0, b1 = result1.params

In [None]:
# Raw observations as a scatter plot, with the fitted straight line on top.
df.plot.scatter(x='speed', y='dist')
plt.plot(x, b0 + b1 * x)

# Optionally export the figure before displaying it.
if FLAG_fig:
    plt.savefig('fig_REG_poy_R_cars_01.eps')
plt.show()

#### 2次モデル
$y = b_0 + b_1 x + b_2 x^2$

```python
df.plot(kind='scatter', x='speed', y='dist')
plt.plot(x, b0+x*b1)
```

In [None]:
# Second-order model: dist explained by speed and speed squared.
result2 = smf.ols(
    formula='dist ~ np.power(speed,2) + speed',
    data=df,
).fit()
print(result2.summary())

# The squared term is written before the linear one in the formula, and the
# coefficients are unpacked in that same order (constant, speed^2, speed).
# NOTE(review): relies on result2.params following the formula's term order
# — confirm.
b0, b2, b1 = result2.params

In [None]:
# Raw observations as a scatter plot, with the fitted quadratic on top.
df.plot.scatter(x='speed', y='dist')
plt.plot(x, b0 + b1 * x + b2 * (x ** 2))

# Optionally export the figure before displaying it.
if FLAG_fig:
    plt.savefig('fig_REG_poy_R_cars_02.eps')
plt.show()

#### 3次モデル
$y = b_0 + b_1 x + b_2 x^2 + b_3 x^3$

In [None]:
# Third-order model: dist explained by speed, speed squared and speed cubed.
result3 = smf.ols(
    formula='dist ~ np.power(speed,3) + np.power(speed,2) + speed',
    data=df,
).fit()
print(result3.summary())

# Terms appear in the formula from highest to lowest power, and the
# coefficients are unpacked in that same order (constant, speed^3, speed^2,
# speed).
# NOTE(review): relies on result3.params following the formula's term order
# — confirm.
b0, b3, b2, b1 = result3.params

In [None]:
# Raw observations as a scatter plot, with the fitted cubic on top.
df.plot.scatter(x='speed', y='dist')
plt.plot(x, b0 + b1 * x + b2 * (x ** 2) + b3 * (x ** 3))

# Optionally export the figure before displaying it.
if FLAG_fig:
    plt.savefig('fig_REG_poy_R_cars_03.eps')
plt.show()

## numpy.polyfit() を用いたカーブフィッティングの例

In [None]:
# Same data, this time fitted with np.polyfit instead of statsmodels.
x, y = df.speed, df.dist

degree = 2
fit = np.polyfit(x, y, deg=degree)
print(fit)

# Wrap the coefficients in a callable polynomial and print its text form.
est = np.poly1d(fit)
print(est)

# Observations plus the fitted polynomial evaluated at the observed x values.
plt.scatter(x, y)
plt.plot(x, est(x))
plt.title('degree = %d' % degree)

In [None]:
# Repeat the polynomial fit with a cubic; x and y come from the previous cell.
degree = 3
fit = np.polyfit(x, y, deg=degree)
print(fit)

# Wrap the coefficients in a callable polynomial and print its text form.
est = np.poly1d(fit)
print(est)

# Observations plus the fitted polynomial evaluated at the observed x values.
plt.scatter(x, y)
plt.plot(x, est(x))
plt.title('degree = %d' % degree)

#### 次の例は，オーバーフィッティング（over fitting）を示す

In [None]:
# Deliberately over-parameterised fit: a 9th-degree polynomial on the same
# x and y as the previous cells.
degree = 9
fit = np.polyfit(x, y, degree)
print(fit)

# Wrap the coefficients in a callable polynomial and print its text form.
est = np.poly1d(fit)
print(est)

plt.scatter(x, y)
# Evaluate the polynomial on a dense grid rather than only at the observed
# x values: straight segments between samples would hide how the curve
# behaves between data points, which is what this demonstration is about.
x_dense = np.linspace(np.min(x), np.max(x), 200)
plt.plot(x_dense, est(x_dense))
plt.title('degree = %d' % degree)