-
Notifications
You must be signed in to change notification settings - Fork 400
/
classification.py
95 lines (70 loc) · 3.58 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""A convenience class that creates a :class:`.ComposerModel` for classification tasks from a vanilla PyTorch model.
:class:`.ComposerClassifier` requires batches in the form: (``input``, ``target``) and includes a basic
classification training loop with a configurable loss (:func:`.soft_cross_entropy` by default) and accuracy logging.
"""
import logging
from typing import Any, Callable, Optional, Tuple, Union
import torch
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import Accuracy
from composer.loss import soft_cross_entropy
from composer.metrics import CrossEntropy
from composer.models import ComposerModel
__all__ = ['ComposerClassifier']
log = logging.getLogger(__name__)
class ComposerClassifier(ComposerModel):
    """Wrap a vanilla PyTorch module as a :class:`.ComposerModel` for classification tasks.

    Batches must be ``(input, target)`` pairs. The wrapped module is applied to ``input``, and the loss is
    obtained by handing the module's output and ``target`` to ``loss_fn``.

    When ``module`` has a ``num_classes`` attribute, its value is copied onto this model as ``num_classes``;
    otherwise ``num_classes`` keeps its class-level default of ``None``.

    Args:
        module (torch.nn.Module): A PyTorch neural network module.
        train_metrics (Metric | MetricCollection, optional): A torchmetric or collection of torchmetrics returned
            by ``metrics(train=True)``. Defaults to ``Accuracy()`` when ``None``.
        val_metrics (Metric | MetricCollection, optional): A torchmetric or collection of torchmetrics returned
            by ``metrics(train=False)``. Defaults to a :class:`MetricCollection` of ``CrossEntropy()`` and
            ``Accuracy()`` when ``None``.
        loss_fn (Callable, optional): Loss function to use. It is called with 1) the output of the model and
            2) ``target`` i.e. labels from the dataset, followed by any extra arguments given to :meth:`loss`.
            Defaults to :func:`.soft_cross_entropy`.

    Example:

    .. testcode::

        import torchvision
        from composer.models import ComposerClassifier

        pytorch_model = torchvision.models.resnet18(pretrained=False)
        model = ComposerClassifier(pytorch_model)
    """

    num_classes: Optional[int] = None

    def __init__(
        self,
        module: torch.nn.Module,
        train_metrics: Optional[Union[Metric, MetricCollection]] = None,
        val_metrics: Optional[Union[Metric, MetricCollection]] = None,
        loss_fn: Callable = soft_cross_entropy,
    ) -> None:
        super().__init__()

        # NOTE: the three attributes below are assigned in a fixed order (train metrics, validation metrics,
        # wrapped module) so that child-module registration order stays stable.

        # Training metrics: plain accuracy unless the caller supplied something.
        self.train_metrics = Accuracy() if train_metrics is None else train_metrics

        # Validation metrics: cross entropy plus accuracy unless the caller supplied something.
        self.val_metrics = (MetricCollection([CrossEntropy(), Accuracy()]) if val_metrics is None else val_metrics)

        self.module = module
        self._loss_fn = loss_fn

        # Mirror the wrapped module's class count only when it actually declares one.
        try:
            declared_classes = self.module.num_classes
        except AttributeError:
            pass
        else:
            self.num_classes = declared_classes

    def loss(self, outputs: Tensor, batch: Tuple[Any, Tensor], *args, **kwargs) -> Tensor:
        """Compute the loss by passing ``outputs`` and the batch's targets to the configured loss function.

        Args:
            outputs (Tensor): Output of the model for this batch.
            batch (Tuple[Any, Tensor]): An ``(input, target)`` pair; only ``target`` is used.
            *args: Extra positional arguments forwarded to the loss function.
            **kwargs: Extra keyword arguments forwarded to the loss function.

        Returns:
            Tensor: Whatever the configured loss function returns.
        """
        _inputs, labels = batch
        return self._loss_fn(outputs, labels, *args, **kwargs)

    def metrics(self, train: bool = False) -> Union[Metric, MetricCollection]:
        """Return the training metrics when ``train`` is truthy, otherwise the validation metrics."""
        if train:
            return self.train_metrics
        return self.val_metrics

    def forward(self, batch: Tuple[Tensor, Any]) -> Tensor:
        """Apply the wrapped module to the input half of an ``(input, target)`` batch.

        Args:
            batch (Tuple[Tensor, Any]): An ``(input, target)`` pair; only ``input`` is used.

        Returns:
            Tensor: The wrapped module's output.
        """
        model_inputs, _labels = batch
        return self.module(model_inputs)

    def validate(self, batch: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
        """Run :meth:`forward` on ``batch`` and pair the result with the batch's targets.

        Args:
            batch (Tuple[Tensor, Tensor]): An ``(input, target)`` pair.

        Returns:
            Tuple[Tensor, Tensor]: ``(outputs, targets)``.
        """
        _inputs, labels = batch
        predictions = self.forward(batch)
        return predictions, labels