In [1]:
import torch

# 数据
x_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]

# 初始化权重和偏置
w1 = torch.tensor([1.0], requires_grad=True)
w2 = torch.tensor([1.0], requires_grad=True)
b = torch.tensor([1.0], requires_grad=True)

# 定义模型的前向传播
def forward(x):
    return w1 * x ** 2 + w2 * x + b

# 定义损失函数
def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

# 训练前的预测
print("Predict (before training)", 4, forward(4).item())

# 训练过程
learning_rate = 0.01
for epoch in range(100):
    for x, y in zip(x_data, y_data):
        l = loss(x, y)
        l.backward()
        with torch.no_grad():
            w1 -= learning_rate * w1.grad
            w2 -= learning_rate * w2.grad
            b -= learning_rate * b.grad

            # 清零累积的梯度
            w1.grad.zero_()
            w2.grad.zero_()
            b.grad.zero_()

    if epoch % 10 == 0: # 每10个epoch打印一次
        print(f"Epoch {epoch}: w1 = {w1.item()}, w2 = {w2.item()}, b = {b.item()}, Loss = {l.item()}")

# 训练后的预测
print("Predict (after training)", 4, forward(4).item())

# 打印最终的参数值
print(f"Final parameters: w1 = {w1.item()}, w2 = {w2.item()}, b = {b.item()}")

Predict (before training) 4 21.0
Epoch 0: w1 = -0.01927196979522705, w2 = 0.6087759733200073, b = 0.8371919989585876, Loss = 18.321826934814453
Epoch 10: w1 = 0.2856519818305969, w2 = 0.7844908833503723, b = 0.937275230884552, Loss = 0.02848036028444767
Epoch 20: w1 = 0.27328190207481384, w2 = 0.8200692534446716, b = 0.9667862057685852, Loss = 0.019148115068674088
Epoch 30: w1 = 0.2643830478191376, w2 = 0.8460506796836853, b = 0.9847812652587891, Loss = 0.014172340743243694
Epoch 40: w1 = 0.2579308748245239, w2 = 0.8655898571014404, b = 0.9950385093688965, Loss = 0.011208509095013142
Epoch 50: w1 = 0.25313133001327515, w2 = 0.8807737827301025, b = 1.000089168548584, Loss = 0.00937769003212452
Epoch 60: w1 = 0.24944846332073212, w2 = 0.8930094838142395, b = 1.0016393661499023, Loss = 0.00820931326597929
Epoch 70: w1 = 0.24652099609375, w2 = 0.9032458066940308, b = 1.0008395910263062, Loss = 0.007440924644470215
Epoch 80: w1 = 0.24410557746887207, w2 = 0.9121231436729431, b = 0.998465955