In [1]:
##
# y = 1.477x + 0.089
#
import numpy as np
# Synthetic training set sampled from the ground-truth line
#     y = 1.477 * x + 0.089 + eps
# with x ~ Uniform(-10, 10) and eps ~ Normal(mean=0, std=0.01).
# NOTE: np.random.normal's second argument is the standard deviation, so the
# noise *variance* is 0.01 ** 2 = 1e-4 -- the floor the fitted MSE approaches.
# Draws are batched (all x values first, then all noise values), so a given
# np.random.seed produces different samples than a per-sample interleaved loop.
num_samples = 1000
x_samples = np.random.uniform(-10., 10., size=num_samples)
noise = np.random.normal(0., 0.01, size=num_samples)
# Shape (num_samples, 2): column 0 holds x, column 1 holds the noisy y.
data = np.column_stack((x_samples, 1.477 * x_samples + 0.089 + noise))

def mse(b, w, points):
    """Return the mean squared error of the line ``y = w * x + b`` on *points*.

    Args:
        b: Intercept of the candidate line.
        w: Slope of the candidate line.
        points: ``(N, 2)`` array-like of samples; column 0 is ``x`` and
            column 1 is ``y``. NumPy arrays and nested lists are both accepted.

    Returns:
        The mean of ``(y - (w * x + b)) ** 2`` over all samples, as a float.

    Raises:
        ValueError: If *points* contains no samples (the mean is undefined).
    """
    # asarray lets plain lists through; tuple indexing like [i, 0] only
    # works on ndarrays.
    pts = np.asarray(points, dtype=float)
    if len(pts) == 0:
        raise ValueError("mse() needs at least one sample")
    residual = pts[:, 1] - (w * pts[:, 0] + b)
    return float(np.mean(residual ** 2))

def step_gradient(b_current, w_current, points, lr):
    """Take one batch gradient-descent step on the MSE loss.

    For ``L = mean((w*x + b - y) ** 2)`` over N samples the gradients are
    ``dL/db = 2/N * sum(w*x + b - y)`` and
    ``dL/dw = 2/N * sum(x * (w*x + b - y))``.

    Args:
        b_current: Current intercept.
        w_current: Current slope.
        points: ``(N, 2)`` array-like of ``[x, y]`` samples; NumPy arrays and
            nested lists are both accepted.
        lr: Learning rate (step size).

    Returns:
        ``[new_b, new_w]`` after one update. With no samples the gradient is
        taken as zero, so the parameters come back unchanged.
    """
    pts = np.asarray(points, dtype=float)
    if pts.size == 0:
        # Nothing to average over: zero gradient (avoids dividing by N == 0).
        b_gradient = w_gradient = 0.0
    else:
        x = pts[:, 0]
        y = pts[:, 1]
        residual = (w_current * x + b_current) - y  # prediction minus target
        scale = 2.0 / len(pts)
        b_gradient = scale * residual.sum()
        w_gradient = scale * (x * residual).sum()
    new_b = b_current - lr * b_gradient
    new_w = w_current - lr * w_gradient
    return [new_b, new_w]

def gradient_desent(points, starting_b, starting_w, lr, num_iterations, log_every=50):
    """Fit ``y = w * x + b`` to *points* with batch gradient descent.

    Each iteration applies one ``step_gradient`` update using every sample.
    (The misspelled function name is kept for existing callers.)

    Args:
        points: ``(N, 2)`` array-like of ``[x, y]`` samples.
        starting_b: Initial intercept.
        starting_w: Initial slope.
        lr: Learning rate passed to each ``step_gradient`` call.
        num_iterations: Number of update steps to run.
        log_every: Print a progress line (with the current MSE) every this
            many steps, starting at step 0. ``0`` or ``None`` disables it.
            Defaults to 50.

    Returns:
        ``[b, w]`` after the final step.
    """
    # Convert once, up front, so every helper sees the same ndarray and the
    # conversion is not repeated on each iteration.
    pts = np.asarray(points, dtype=float)
    b = starting_b
    w = starting_w
    for step in range(num_iterations):
        b, w = step_gradient(b, w, pts, lr)
        # The loss only feeds the progress line, so skip the extra full pass
        # over the data on steps that do not log.
        if log_every and step % log_every == 0:
            loss = mse(b, w, pts)
            print(f"interation: {step}, loss: {loss}, w: {w}, b: {b}")
    return [b, w]

def main():
    """Fit the module-level ``data`` with gradient descent and print the result."""
    lr = 0.01  # learning rate
    initial_b = 0
    initial_w = 0
    num_iterations = 1000
    b, w = gradient_desent(data, initial_b, initial_w, lr, num_iterations)
    loss = mse(b, w, data)
    print(f'Final loss: {loss}, w: {w}, b: {b}')


# Guard the entry point so importing this module does not start training.
# In a notebook kernel __name__ is "__main__", so the cell still runs it.
if __name__ == "__main__":
    main()

interation: 0, loss: 8.145128930554039, w: 0.9822406284088075, b: 0.0004636170984777005
interation: 50, loss: 0.0011597734068617845, w: 1.4769904141393744, b: 0.05656007484290752
interation: 100, loss: 0.0002423091374761839, w: 1.4770190214748495, b: 0.07723921627682867
interation: 150, loss: 0.00012062010536763872, w: 1.4770294400552493, b: 0.08477040693289756
interation: 200, loss: 0.0001044797269031638, w: 1.4770332344253203, b: 0.0875132109743178
interation: 250, loss: 0.00010233892742056139, w: 1.4770346163068901, b: 0.08851211995872504
interation: 300, loss: 0.00010205497977740142, w: 1.4770351195779567, b: 0.08887591523686417
interation: 350, loss: 0.00010201731802313095, w: 1.477035302865564, b: 0.0890084067916126
interation: 400, loss: 0.00010201232270878638, w: 1.4770353696175575, b: 0.08905665923779457
interation: 450, loss: 0.00010201166014896044, w: 1.4770353939281409, b: 0.08907423242249277
interation: 500, loss: 0.00010201157226949951, w: 1.477035402781876, b: 0.08908063