-
Notifications
You must be signed in to change notification settings - Fork 400
/
classification.py
94 lines (71 loc) · 3.71 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""A convenience class that creates a :class:`.ComposerModel` for classification tasks from a vanilla PyTorch model.
:class:`.ComposerClassifier` requires batches in the form: (``input``, ``target``) and includes a basic
classification training loop with :func:`.soft_cross_entropy` loss and accuracy logging.
"""
import logging
from typing import Any, Optional, Tuple, Union
import torch
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import Accuracy
from composer.loss import loss_registry
from composer.metrics import CrossEntropy
from composer.models import ComposerModel
__all__ = ['ComposerClassifier']
log = logging.getLogger(__name__)
class ComposerClassifier(ComposerModel):
    """A convenience class that creates a :class:`.ComposerModel` for classification tasks from a vanilla PyTorch model.

    :class:`.ComposerClassifier` requires batches in the form: (``input``, ``target``) and includes a basic
    classification training loop with :func:`.soft_cross_entropy` loss and accuracy logging.

    Args:
        module (torch.nn.Module): A PyTorch neural network module.
        loss_name (str, optional): Loss function to use. E.g. 'soft_cross_entropy' or
            'binary_cross_entropy_with_logits'. Loss function must be in
            :mod:`~composer.loss.loss`. Default: ``'soft_cross_entropy'``".

    Returns:
        ComposerClassifier: An instance of :class:`.ComposerClassifier`.

    Example:

    .. testcode::

        import torchvision
        from composer.models import ComposerClassifier

        pytorch_model = torchvision.models.resnet18(pretrained=False)
        model = ComposerClassifier(pytorch_model)
    """

    # Populated from the wrapped module's ``num_classes`` attribute when present.
    num_classes: Optional[int] = None

    def __init__(self, module: torch.nn.Module, loss_name: str = 'soft_cross_entropy') -> None:
        super().__init__()
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.val_loss = CrossEntropy()
        self.module = module

        # Fail fast on an unknown loss name instead of at the first training step.
        # (Membership test on the registry mapping directly; ``.keys()`` is redundant.)
        if loss_name not in loss_registry:
            raise ValueError(f'Unrecognized loss function: {loss_name}. Please ensure the '
                             'specified loss function is present in composer.loss.loss.py')
        self._loss_fxn = loss_registry[loss_name]

        # Mirror the wrapped module's class count when it exposes one; plain
        # attribute access — the name is constant, so getattr adds nothing.
        if hasattr(self.module, 'num_classes'):
            self.num_classes = self.module.num_classes

        if loss_name == 'binary_cross_entropy_with_logits':
            log.warning('UserWarning: Using `binary_cross_entropy_loss_with_logits` '
                        'without using `initializers.linear_log_constant_bias` can degrade '
                        'performance. '
                        'Please ensure you are using `initializers. '
                        'linear_log_constant_bias`.')

    def loss(self, outputs: Any, batch: Any, *args, **kwargs) -> Tensor:
        """Compute the loss between ``outputs`` and the batch's targets.

        Args:
            outputs (Any): Model outputs; must be a single :class:`~torch.Tensor`.
            batch (Any): An (``input``, ``target``) pair; ``target`` must be a single Tensor.

        Returns:
            Tensor: The loss as computed by the configured loss function.

        Raises:
            ValueError: If ``outputs`` or the batch targets are not single Tensors.
        """
        _, targets = batch
        if not isinstance(outputs, Tensor):  # to pass typechecking
            raise ValueError('Loss expects input as Tensor')
        if not isinstance(targets, Tensor):
            raise ValueError('Loss does not support multiple target Tensors')
        return self._loss_fxn(outputs, targets, *args, **kwargs)

    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        """Return the metric(s) for the given phase.

        Args:
            train (bool, optional): If ``True``, return the training accuracy metric;
                otherwise a collection of validation accuracy and loss. Default: ``False``.
        """
        return self.train_acc if train else MetricCollection([self.val_acc, self.val_loss])

    def forward(self, batch: Tuple[Tensor, Tensor]) -> Tensor:
        """Run the wrapped module on the batch's inputs, ignoring its targets."""
        inputs, _ = batch
        outputs = self.module(inputs)
        return outputs

    def validate(self, batch: Any) -> Tuple[Any, Any]:
        """Run a forward pass for validation.

        Args:
            batch (Any): An (``input``, ``target``) pair.

        Returns:
            Tuple[Any, Any]: The model outputs and the batch targets, for metric computation.
        """
        _, targets = batch
        outputs = self.forward(batch)
        return outputs, targets