In [3]:
import numpy as np
import torch
from torch import nn
from d2l import torch as d2l


class LinearRegression(d2l.Module):  #@save
    """The linear regression model implemented with high-level APIs.

    The model is a single ``nn.LazyLinear`` layer with one output. Its weight
    is initialized from N(0, 0.01) and its bias is set to zero.

    Args:
        lr: learning rate; recorded by ``save_hyperparameters`` and read back
            as ``self.lr`` when the optimizer is configured.
    """
    def __init__(self, lr):
        super().__init__()
        self.save_hyperparameters()
        # LazyLinear only needs the output dimension; the input dimension is
        # inferred from the first batch that passes through the layer.
        self.net = nn.LazyLinear(1)
        # Until that first forward pass, weight and bias are zero-sized
        # placeholders, and materializing them re-runs the layer's default
        # reset_parameters(). Writing to ``.data`` here would therefore be
        # silently discarded. Instead, the custom init is applied by a one-shot
        # hook placed in front of LazyLinear's own shape-inference hook.
        self._init_handle = self.net.register_forward_pre_hook(
            self._init_lazy_params, prepend=True)

    def _init_lazy_params(self, module, args):
        """Materialize ``module`` from the first input, then apply the init.

        Parameters that are already materialized (e.g. loaded from a state
        dict before the first forward pass) are left untouched. The hook
        removes itself after its first call.
        """
        if module.has_uninitialized_params():
            # Infer in_features from the input and allocate real tensors.
            module.initialize_parameters(*args)
            module.weight.data.normal_(0, 0.01)
            module.bias.data.fill_(0)
        self._init_handle.remove()


@d2l.add_to_class(LinearRegression)  #@save
def forward(self, X):
    """Feed ``X`` through the wrapped linear layer and return its output."""
    y_hat = self.net(X)
    return y_hat


# Loss: mean squared error between predictions and targets.
@d2l.add_to_class(LinearRegression)  #@save
def loss(self, y_hat, y):
    """Return the mean squared error between ``y_hat`` and ``y``.

    Uses ``nn.functional.mse_loss`` with its default ``'mean'`` reduction,
    which is the same computation an ``nn.MSELoss()`` instance performs.
    """
    return nn.functional.mse_loss(y_hat, y)


# Optimization algorithm: minibatch SGD over all model parameters.
@d2l.add_to_class(LinearRegression)  #@save
def configure_optimizers(self):
    """Build an SGD optimizer over ``self.parameters()`` using rate ``self.lr``."""
    params = self.parameters()
    return torch.optim.SGD(params, lr=self.lr)


# Training: fit the model on synthetic regression data. w = [2, -3.4] and
# b = 4.2 are presumably the ground-truth parameters the synthetic data is
# generated from — TODO confirm against d2l.SyntheticRegressionData.
# NOTE(review): keep the statement order as-is; the data generator presumably
# draws from torch's global RNG, so reordering could change the run.
model = LinearRegression(lr=0.03)
data = d2l.SyntheticRegressionData(w=torch.tensor([2, -3.4]), b=4.2)
trainer = d2l.Trainer(max_epochs=3)
trainer.fit(model, data)