# The base classification model

In [1]:
import torch
from d2l import torch as d2l

In [2]:
class Classifier(d2l.Module):  # @save
    """Base class for classification models.

    Adds a validation step that reports loss and accuracy. It relies on
    ``self.loss``, ``self.accuracy`` and ``self.plot`` being provided
    elsewhere (by ``d2l.Module`` or attached later); none are defined here.
    """

    def validation_step(self, batch):
        """Plot the loss and accuracy of the model on one validation batch.

        ``batch`` is indexable: every element except the last is passed to the
        model as a positional input, and the last element is used as the
        labels. Both metrics are sent to ``self.plot`` with ``train=False``;
        the method itself returns ``None``.
        """
        features = batch[:-1]
        # Calling the model (not its constructor) runs its forward computation
        # on the unpacked inputs.
        predictions = self(*features)
        labels = batch[-1]
        self.plot("loss", self.loss(predictions, labels), train=False)
        self.plot("acc", self.accuracy(predictions, labels), train=False)

In [3]:
@d2l.add_to_class(d2l.Module)
def configure_optimizers(self):
    """Create the training optimizer: plain SGD over all model parameters.

    Attached to ``d2l.Module`` itself, so every module inherits it. The
    learning rate is read from ``self.lr``.
    """
    params = self.parameters()
    return torch.optim.SGD(params, lr=self.lr)

In [4]:
@d2l.add_to_class(Classifier)  # @save
def accuracy(self, Y_hat, Y, averaged=True):
    """Compute the classification accuracy of predictions against labels.

    Args:
        Y_hat: Prediction scores whose last dimension indexes the classes;
            all leading dimensions are flattened into one row per example.
        Y: True class labels; flattened to one dimension for the comparison.
        averaged: If True, return the fraction of correct predictions;
            otherwise return the per-example correctness vector.

    Returns:
        A float32 tensor: a scalar fraction of correct predictions when
        ``averaged`` is True, else a 1-D tensor holding 1.0 for each correct
        prediction and 0.0 for each incorrect one.
    """
    # Collapse leading dimensions so each row holds one example's class scores.
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    # The predicted class is the highest-scoring one; cast it to the label
    # dtype so the equality test compares like with like.
    preds = Y_hat.argmax(dim=1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

- Given a predicted probability distribution y_hat, we choose the class with the highest probability as the prediction, then compare it with the true label to compute the accuracy of our classifier.

- The classification accuracy is the fraction of all predictions that are correct.