In [None]:
%matplotlib inline


# Robust linear estimator fitting

Here a sine function is fit with a polynomial of order 3, for values
close to zero.

Robust fitting is demonstrated in different situations:

- No measurement errors, only modelling errors (fitting a sine with a
  polynomial)

- Measurement errors in X

- Measurement errors in y

The median absolute deviation on new, uncorrupted data is used to judge
the quality of the prediction.
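To make the quality measure concrete, here is a minimal sketch of a median absolute error computed with NumPy. The helper name `median_abs_error` and the toy arrays are hypothetical and chosen only for illustration; scikit-learn offers a comparable ready-made function, `sklearn.metrics.median_absolute_error`.

```python
import numpy as np


def median_abs_error(y_true, y_pred):
    """Median of the absolute residuals |y_true - y_pred| (hypothetical helper)."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.median(np.abs(y_true - y_pred)))


# Made-up clean targets and predictions; the last prediction is badly off.
y_clean = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
y_hat = np.array([1.1, 1.9, 3.2, 4.0, 15.0])

mad = median_abs_error(y_clean, y_hat)
mse = float(np.mean((y_clean - y_hat) ** 2))
print(f"median absolute error: {mad:.3f}, mean squared error: {mse:.3f}")
```

In this toy case the single bad prediction dominates the mean squared error but barely moves the median, which is why a median-based score suits a comparison of robust estimators.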

What we can see:

- RANSAC is good for strong outliers in the y direction

- TheilSen is good for small outliers in both the X and y directions, but
  has a break point above which it performs worse than OLS.

- The scores of HuberRegressor cannot be compared directly with those of
  TheilSen and RANSAC, because it does not try to filter out the outliers
  completely; it only lessens their effect.
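The comparison described in the list above can be sketched as follows. This is an illustrative setup rather than the notebook's own experiment: the straight-line data, the share of corrupted points (10%), and the outlier offset (+30) are all assumed values, and only outliers in the y direction are covered.

```python
import numpy as np
from sklearn.linear_model import (
    LinearRegression, TheilSenRegressor, RANSACRegressor, HuberRegressor)

rng = np.random.RandomState(0)

# Assumed ground truth: y = 2x + 5 with small Gaussian noise.
X = rng.normal(size=(200, 1))
y = 2 * X[:, 0] + 5 + 0.1 * rng.normal(size=200)

# Corrupt 10% of the training targets with strong outliers in y.
y_corrupt = y.copy()
y_corrupt[:20] += 30

# New, uncorrupted data used only for scoring.
X_new = rng.normal(size=(200, 1))
y_new = 2 * X_new[:, 0] + 5

estimators = {
    "OLS": LinearRegression(),
    "Theil-Sen": TheilSenRegressor(random_state=0),
    "RANSAC": RANSACRegressor(random_state=0),
    "Huber": HuberRegressor(),
}

# Fit on corrupted data, score with the median absolute error on clean data.
errors = {}
for name, est in estimators.items():
    est.fit(X, y_corrupt)
    errors[name] = float(np.median(np.abs(est.predict(X_new) - y_new)))

for name, err in errors.items():
    print(f"{name:10s} median absolute error: {err:.3f}")
```

With outliers this strong, RANSAC would be expected to score clearly better than OLS, in line with the first bullet; how the other estimators rank depends on the assumed outlier share and size.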


In [92]:
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd

from sklearn.linear_model import (
    LinearRegression, TheilSenRegressor, RANSACRegressor, HuberRegressor)
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline

np.random.seed(42)

# True slope, true intercept, and sample size of the synthetic data.
a = 2
b = 5
n = 400

# y depends linearly on x1, plus Gaussian noise with standard deviation 0.5.
x1 = np.random.normal(size=n)
noise = 0.5 * np.random.normal(size=n)
y = a * x1 + b + noise

# Collect predictor and response in one DataFrame for plotting.
data = pd.DataFrame({"x1": x1, "y": y})
data


In [82]:
import altair as alt

p1 = alt.Chart(data).mark_circle(size=60).encode(
    x='x1',
    y='y',
)
p1

In [66]:
X = x1[:, np.newaxis]
X.size

400

In [68]:
model = LinearRegression()
model.fit(X, y)
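One quick check after fitting is to compare the estimated parameters with the true values `a` and `b` used to generate the data. The sketch below regenerates the same kind of data so that it runs on its own; the names `slope` and `intercept` are introduced here for illustration and do not appear in the notebook.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(42)

# Same generating process as the notebook: y = a * x1 + b + noise.
a, b, n = 2, 5, 400
x1 = np.random.normal(size=n)
y = a * x1 + b + 0.5 * np.random.normal(size=n)

# scikit-learn expects a 2-D feature matrix, hence the extra axis.
model = LinearRegression().fit(x1[:, np.newaxis], y)

# coef_ holds one slope per feature; intercept_ is the constant term.
slope = float(model.coef_[0])
intercept = float(model.intercept_)
print(f"slope={slope:.3f} (true {a}), intercept={intercept:.3f} (true {b})")
```

With 400 points and no outliers, both estimates should land close to the true values, which gives a baseline before any corrupted data is introduced.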

In [69]:
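# X_test and y_test are not defined in the cells shown; assign held-out data
# to them before running this cell.
mse = mean_squared_error(y_test, model.predict(X_test))
mse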

25.612798571963104
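The cell above relies on `X_test` and `y_test`, which are not defined in the cells shown. One way to obtain held-out data is to draw a second sample from the same generating process, as sketched below. The helper `make_data` and the test size of 100 are assumptions, so this sketch is not expected to reproduce the value printed above.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

rng = np.random.RandomState(42)
a, b, n = 2, 5, 400


def make_data(n_samples):
    """Draw (X, y) from y = a * x + b + noise (hypothetical helper)."""
    x = rng.normal(size=n_samples)
    y = a * x + b + 0.5 * rng.normal(size=n_samples)
    return x[:, np.newaxis], y


X_train, y_train = make_data(n)
X_test, y_test = make_data(100)  # assumed held-out sample size

model = LinearRegression().fit(X_train, y_train)

# mean_squared_error takes the true targets first, then the predictions.
mse = mean_squared_error(y_test, model.predict(X_test))
print(f"held-out MSE: {mse:.3f}")
```

When the test data follow the same process as the training data, the held-out MSE should sit near the noise variance (0.5 squared, or 0.25); a much larger value usually means the test targets were generated differently.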

In [74]:
x_plot = np.linspace(X.min(), X.max())
y_plot = model.predict(x_plot[:, np.newaxis])

In [91]:
data_plot = pd.concat([pd.DataFrame(x_plot),pd.DataFrame(y_plot)], axis=1)
data_plot.columns = ["x_plot", "y_plot"]
data_plot

     x_plot    y_plot
0 -3.241267 -1.307353
1 -3.096492 -1.026423
2 -2.951716 -0.745493
3 -2.806941 -0.464563
4 -2.662165 -0.183633
5 -2.517390  0.097296
6 -2.372614  0.378226
7 -2.227839  0.659156
8 -2.083063  0.940086
9 -1.938288  1.221016


In [89]:
p2 = alt.Chart(data_plot).mark_line().encode(
    x='x_plot',
    y='y_plot',
    color=alt.value("#FFAA00"),
)

In [90]:
p1 + p2