# 2.3 The Base Classification Model

This section provides a base class for classification models to simplify future code.

## 2.3.1 The `Classifier` Class

We define the `Classifier` class below. In `validation_step` we report both the loss and the classification accuracy on a validation batch. The plot is updated once every `num_val_batches` batches, so each plotted point is the loss and accuracy averaged over the whole validation set. These averages are not exactly correct if the final batch contains fewer examples, because every batch is weighted equally regardless of its size, but we ignore this minor difference to keep the code simple.

In [1]:
import torch
from d2l import torch as d2l

class Classifier(d2l.Module):  #@save
    """The base class of classification models."""
    def validation_step(self, batch):
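        Y_hat = self(*batch[:-1])  # forward pass on the inputs to get predictions
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)  # compute the loss and plot it on the validation loss curve
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)  # compute the accuracy and plot it on the validation accuracy curve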
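
To see why averaging per-batch values is only approximately right, consider a small made-up example. The sketch below is plain Python and does not use the d2l code; the correctness flags in `correct` and the `batch_size` are hypothetical values chosen so that the last batch is smaller than the others.

```python
# Hypothetical per-example correctness flags for 5 validation examples
# (1.0 = correct, 0.0 = wrong). These values are made up for illustration.
correct = [1.0, 1.0, 1.0, 0.0, 0.0]
batch_size = 2  # assumed batch size; the last batch then holds 1 example

# Split into batches: [1, 1], [1, 0], [0]
batches = [correct[i:i + batch_size]
           for i in range(0, len(correct), batch_size)]

# Averaging the per-batch accuracies gives every batch equal weight.
batch_means = [sum(b) / len(b) for b in batches]
mean_of_means = sum(batch_means) / len(batch_means)

# The exact accuracy gives every example equal weight.
exact = sum(correct) / len(correct)

print(batch_means)    # [1.0, 0.5, 0.0]
print(mean_of_means)  # 0.5
print(exact)          # 0.6
```

The single example in the short final batch counts as much as two examples in a full batch, which is where the gap comes from. With many full batches and one short one, the difference is usually small.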

By default we use a stochastic gradient descent optimizer, operating on minibatches, just as we did in the context of linear regression.

In [2]:
@d2l.add_to_class(d2l.Module)  #@save
def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), lr=self.lr)  # use stochastic gradient descent as the optimizer
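
As a rough illustration of what a single step of this optimizer does, the following sketch applies `torch.optim.SGD` to a toy parameter outside of any d2l class. The tensor `w`, the quadratic loss, and the learning rate `lr = 0.1` are assumptions made for the example only.

```python
import torch

# A toy parameter and learning rate (both hypothetical).
w = torch.tensor([1.0, -2.0], requires_grad=True)
lr = 0.1
optimizer = torch.optim.SGD([w], lr=lr)

# A simple loss whose gradient with respect to w is 2 * w.
loss = (w ** 2).sum()

optimizer.zero_grad()  # clear any stale gradients
loss.backward()        # populate w.grad with 2 * w

# Plain SGD (no momentum, no weight decay) updates w <- w - lr * grad.
expected = w.detach() - lr * w.grad
optimizer.step()

print(w.detach())  # tensor([ 0.8000, -1.6000])
```

In a real training loop the gradient comes from the loss on one minibatch, and this zero-grad, backward, step sequence repeats for every minibatch.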

## 2.3.2 Accuracy

A prediction is correct when the predicted class matches the label `y`.
The classification accuracy is the fraction of all predictions that are correct.

Accuracy is computed as follows.
First, if `y_hat` is a matrix,
we assume that the second dimension stores prediction scores for each class.
We use `argmax` to obtain the predicted class by the index for the largest entry in each row.
Then we compare the predicted class with the ground truth `y` elementwise.
Since the equality operator `==` is sensitive to data types,
we convert `y_hat`'s data type to match that of `y`.
The result is a tensor containing entries of 0 (false) and 1 (true).
Taking the mean yields the accuracy, i.e., the fraction of correct predictions; with `averaged=False` the per-example results are returned instead.


In [3]:
# Compute the model's prediction accuracy.
@d2l.add_to_class(Classifier)  #@save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the accuracy (fraction of correct predictions)."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))  # reshape Y_hat into a 2D tensor of shape (examples, classes)
    preds = Y_hat.argmax(axis=1).type(Y.dtype)  # index of the largest score in each row is the predicted class
    compare = (preds == Y.reshape(-1)).type(torch.float32)  # 1.0 where prediction equals label, else 0.0
    return compare.mean() if averaged else compare  # mean accuracy if averaged=True, else per-example results
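
To make the steps concrete, here is a small standalone sketch that repeats the same logic as a free function, so it runs without d2l. The score matrix `Y_hat` and labels `Y` are made-up values for illustration, and the function omits `self` because it is not attached to a `Classifier` instance.

```python
import torch

def accuracy(Y_hat, Y, averaged=True):
    """Standalone version of the accuracy logic, for illustration only."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim=1).type(Y.dtype)  # predicted class per row
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

# Hypothetical scores for 2 examples over 3 classes, with their labels.
Y_hat = torch.tensor([[0.1, 0.3, 0.6],
                      [0.3, 0.2, 0.5]])
Y = torch.tensor([0, 2])

# Both rows have their largest score at index 2, so preds = [2, 2].
# Only the second example matches its label.
print(accuracy(Y_hat, Y, averaged=False))  # tensor([0., 1.])
print(accuracy(Y_hat, Y))                  # tensor(0.5000)
```

The `averaged=False` form is handy when per-example results need to be summed across batches of different sizes before dividing by the total count.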