In [1]:
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import SGDRegressor
from matplotlib import pyplot as plt
from sklearn.metrics import mean_squared_error
import pandas as pd

import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot
import plotly.io as pio
pio.renderers.default = 'iframe'

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
X = 2 * np.random.rand(1000, 1)
y = 4 + 3 * X 

In [4]:
GD_model = SGDRegressor(max_iter=1, eta0=0.0001, warm_start=True)

In [5]:
out = []

for i in range(100):
    GD_model.fit(X, y) # Singel itteration of GD
    out.append([
                i, GD_model.coef_[0], # Model Parameter (1)
                GD_model.intercept_[0], # Model Parameter (2)
                mean_squared_error(y, GD_model.predict(X))
                ])
    
    df = pd.DataFrame(out, columns=["itteration", "m", "b", "MSE"])
    _out = out[-1]
    print(f"Itteration Number: {_out[0]}, Param (1): {_out[1]}, Param (2): {_out[2]}, MSE: {_out[3]}")
    
    # input()
    # break

Itteration Number: 0, Param (1): 0.18396618778938328, Param (2): 0.16142218006617512, MSE: 46.75320620266479
Itteration Number: 1, Param (1): 0.3593716570342565, Param (2): 0.3150383811906591, MSE: 42.182721307851025
Itteration Number: 2, Param (1): 0.5234095573391275, Param (2): 0.46043925567351335, MSE: 38.10098684398194
Itteration Number: 3, Param (1): 0.679471589958668, Param (2): 0.5988273539165128, MSE: 34.41056103162077
Itteration Number: 4, Param (1): 0.8274031124005242, Param (2): 0.7303163357655722, MSE: 31.083271402157685
Itteration Number: 5, Param (1): 0.9686228463418723, Param (2): 0.8555827731589665, MSE: 28.068142293553255
Itteration Number: 6, Param (1): 1.1013290176303465, Param (2): 0.9743896669539717, MSE: 25.36493237677575
Itteration Number: 7, Param (1): 1.2292987442748926, Param (2): 1.0879427142811984, MSE: 22.898009079501232
Itteration Number: 8, Param (1): 1.3502595711226082, Param (2): 1.1958566842207843, MSE: 20.677776779448624
Itteration Number: 9, Param (1

In [6]:
df

Unnamed: 0,itteration,m,b,MSE
0,0,0.183966,0.161422,46.753206
1,1,0.359372,0.315038,42.182721
2,2,0.523410,0.460439,38.100987
3,3,0.679472,0.598827,34.410561
4,4,0.827403,0.730316,31.083271
...,...,...,...,...
95,95,3.477354,3.384322,0.094576
96,96,3.476778,3.387420,0.093695
97,97,3.476281,3.390601,0.092807
98,98,3.475720,3.393754,0.091939


In [7]:
init_notebook_mode(connected=True)

trace = go.Scatter3d(
    x=df['b'],
    y=df['m'],
    z=df['MSE'],
    mode='lines+markers',
    marker=dict(size=5, color=df['MSE'], colorscale='Viridis', opacity=0.8),
    line=dict(color='blue', width=2)
)

layout = go.Layout(
    title='',
    scene=dict(
        xaxis=dict(title='b'),
        yaxis=dict(title='m'),
        zaxis=dict(title='MSE')
    ),
    width=800,
    height=600
)

fig = go.Figure(data=[trace], layout=layout)
iplot(fig)