
In [None]:
!pip install d2l==1.0.3

# The Base Classification Model


In [2]:
import torch
from d2l import torch as d2l

## The `Classifier` Class


In [3]:
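class Classifier(d2l.Module):
    """The base class of classification models."""
    def validation_step(self, batch):
        # All but the last element of the batch are inputs; the last is the labels.
        Y_hat = self(*batch[:-1])
        # Report validation loss and accuracy for this batch.
        self.plot('loss', self.loss(Y_hat, batch[-1]), train=False)
        self.plot('acc', self.accuracy(Y_hat, batch[-1]), train=False)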

In [4]:
@d2l.add_to_class(d2l.Module)
def configure_optimizers(self):
    return torch.optim.SGD(self.parameters(), lr=self.lr)
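
As a minimal sketch of what the optimizer returned above does, the snippet below applies one plain SGD step to a single parameter and compares the result with the update rule `w_new = w_old - lr * grad`. The tensor `w`, the value of `lr`, and the toy loss are illustrative assumptions, not part of the d2l code.

```python
import torch

# Hypothetical single parameter and learning rate, for illustration only.
w = torch.tensor([1.0, -2.0], requires_grad=True)
lr = 0.1

# Same optimizer type that configure_optimizers returns (plain SGD, no momentum).
optimizer = torch.optim.SGD([w], lr=lr)

# Toy loss: sum of squares, whose gradient with respect to w is 2 * w.
loss = (w ** 2).sum()
optimizer.zero_grad()
loss.backward()

# Keep copies of the parameter and its gradient before the update.
w_before = w.detach().clone()
grad = w.grad.detach().clone()
optimizer.step()

# Plain SGD update: w_new = w_old - lr * grad.
expected = w_before - lr * grad
print(w.detach(), expected)
```

With these made-up values the gradient is `2 * w`, so each entry of `w` shrinks by 20% in one step.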

## Accuracy


In [5]:
@d2l.add_to_class(Classifier)
def accuracy(self, Y_hat, Y, averaged=True):
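    """Compute the fraction of correct predictions.

    If `averaged` is False, return the per-example 0/1 correctness instead.
    """
    # Flatten any leading dimensions so each row holds the scores for one example.
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    # The predicted class is the index of the largest score.
    preds = Y_hat.argmax(axis=1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare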
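
To make the computation concrete, here is a small worked example. It is a sketch that copies the same logic into a standalone function (so it runs without d2l) and uses made-up scores and labels; the standalone `accuracy` function and the tensors are illustrative assumptions.

```python
import torch

def accuracy(Y_hat, Y, averaged=True):
    """Fraction of correct predictions, or per-example 0/1 values if averaged is False."""
    Y_hat = Y_hat.reshape((-1, Y_hat.shape[-1]))
    preds = Y_hat.argmax(dim=1).type(Y.dtype)
    compare = (preds == Y.reshape(-1)).type(torch.float32)
    return compare.mean() if averaged else compare

# Made-up scores for 4 examples and 3 classes.
Y_hat = torch.tensor([[0.1, 0.7, 0.2],
                      [2.0, 0.5, 0.3],
                      [0.2, 0.2, 0.6],
                      [0.9, 0.05, 0.05]])
# Made-up ground-truth labels; the third example is misclassified.
Y = torch.tensor([1, 0, 1, 0])

per_example = accuracy(Y_hat, Y, averaged=False)  # tensor([1., 1., 0., 1.])
mean_acc = accuracy(Y_hat, Y)                     # tensor(0.7500)
print(per_example, mean_acc)
```

Three of the four predictions match their labels, so the averaged accuracy is 0.75.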

## Discussion (Takeaway Message)

In [None]:
'''
- To determine the predicted class, argmax is applied to the per-class prediction scores output by the model, and the class with the highest score is selected.
- At each validation step, the loss value and the classification accuracy are computed.
'''
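
One point worth checking here: because softmax preserves the ordering of its inputs, taking argmax over raw scores gives the same class as taking argmax over softmax probabilities. The sketch below illustrates this with made-up logits; the tensor values are assumptions for illustration only.

```python
import torch

# Made-up logits for 2 examples and 3 classes.
logits = torch.tensor([[2.0, -1.0, 0.5],
                       [0.1, 0.3, 3.2]])

# Softmax turns each row of scores into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)

# The index of the largest entry is unchanged by softmax.
preds_from_logits = logits.argmax(dim=1)
preds_from_probs = probs.argmax(dim=1)
print(preds_from_logits, preds_from_probs)  # both tensor([0, 2])
```

This is why an accuracy computation can work directly on unnormalized scores without applying softmax first.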

## Summary

Classification is a sufficiently common problem that it warrants its own convenience functions. Of central importance in classification is the *accuracy* of the classifier. Note that while we often care primarily about accuracy, we train classifiers to optimize a variety of other objectives for statistical and computational reasons. However, regardless of which loss function was minimized during training, it is useful to have a convenience method for assessing the accuracy of our classifier empirically.
