-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Copy pathcls_head.py
154 lines (128 loc) · 5.91 KB
/
cls_head.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.evaluation.metrics import Accuracy
from mmcls.registry import MODELS
from mmcls.structures import ClsDataSample
from .base_head import BaseHead
@MODELS.register_module()
class ClsHead(BaseHead):
    """Classification head.

    This head has no layers of its own: :meth:`forward` returns the feature
    of the last backbone stage unchanged and treats it as the classification
    score. :meth:`loss` compares that score with the ground-truth labels, and
    :meth:`predict` turns it into softmax scores and argmax labels.

    Args:
        loss (dict | nn.Module): Config of classification loss, or an
            already-built loss module. Defaults to
            ``dict(type='CrossEntropyLoss', loss_weight=1.0)``.
        topk (int | Tuple[int]): Top-k accuracy. Defaults to ``(1, )``.
        cal_acc (bool): Whether to calculate accuracy during training.
            If you use batch augmentations like Mixup and CutMix during
            training, it is pointless to calculate accuracy.
            Defaults to False.
        init_cfg (dict, optional): the config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 loss: Union[dict, nn.Module] = dict(
                     type='CrossEntropyLoss', loss_weight=1.0),
                 topk: Union[int, Tuple[int]] = (1, ),
                 cal_acc: bool = False,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)

        # Stored as given (int or tuple); normalized where it is iterated.
        self.topk = topk
        # Accept either a registry config (built here) or a ready-made module.
        if not isinstance(loss, nn.Module):
            loss = MODELS.build(loss)
        self.loss_module = loss
        self.cal_acc = cal_acc

    def pre_logits(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The process before the final classification head.

        The input ``feats`` is a tuple of tensor, and each tensor is the
        feature of a backbone stage. In ``ClsHead``, we just obtain the feature
        of the last stage.
        """
        # The ClsHead doesn't have other module, just return after unpacking.
        return feats[-1]

    def forward(self, feats: Tuple[torch.Tensor]) -> torch.Tensor:
        """The forward process.

        Returns the output of :meth:`pre_logits` unchanged, i.e. the feature
        of the last stage.
        """
        pre_logits = self.pre_logits(feats)
        # The ClsHead doesn't have the final classification head,
        # just return the unpacked inputs.
        return pre_logits

    def loss(self, feats: Tuple[torch.Tensor],
             data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Calculate losses from the classification score.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        losses = self._get_loss(cls_score, data_samples, **kwargs)
        return losses

    def _get_loss(self, cls_score: torch.Tensor,
                  data_samples: List[ClsDataSample], **kwargs) -> dict:
        """Unpack data samples and compute loss.

        Args:
            cls_score (torch.Tensor): The classification score returned by
                :meth:`forward`.
            data_samples (List[ClsDataSample]): The annotation data of
                every samples. Must be non-empty: the first sample decides
                whether targets are read from ``gt_label.score`` or
                ``gt_label.label``.
            **kwargs: Other keyword arguments to forward the loss module.

        Returns:
            dict: ``'loss'`` from the loss module and, when ``cal_acc`` is
            enabled, one ``'accuracy_top-{k}'`` entry per value in ``topk``.
        """
        # Unpack data samples and pack targets
        if 'score' in data_samples[0].gt_label:
            # Batch augmentation may convert labels to one-hot format scores.
            target = torch.stack([i.gt_label.score for i in data_samples])
        else:
            target = torch.cat([i.gt_label.label for i in data_samples])

        # compute loss
        losses = dict()
        loss = self.loss_module(
            cls_score, target, avg_factor=cls_score.size(0), **kwargs)
        losses['loss'] = loss

        # compute accuracy
        if self.cal_acc:
            assert target.ndim == 1, 'If you enable batch augmentation ' \
                'like mixup during training, `cal_acc` is pointless.'
            # ``topk`` may be a bare int (allowed by the constructor); make it
            # a tuple so it can be zipped with the per-k accuracy results.
            # Previously an int ``topk`` raised TypeError in ``zip``.
            topk = (self.topk, ) if isinstance(self.topk, int) else self.topk
            acc = Accuracy.calculate(cls_score, target, topk=topk)
            losses.update(
                {f'accuracy_top-{k}': a
                 for k, a in zip(topk, acc)})
        return losses

    def predict(
        self,
        feats: Tuple[torch.Tensor],
        data_samples: Optional[List[ClsDataSample]] = None
    ) -> List[ClsDataSample]:
        """Inference without augmentation.

        Args:
            feats (tuple[Tensor]): The features extracted from the backbone.
                Multiple stage inputs are acceptable but only the last stage
                will be used to classify. The shape of every item should be
                ``(num_samples, num_classes)``.
            data_samples (List[ClsDataSample], optional): The annotation
                data of every samples. If not None, set ``pred_label`` of
                the input data samples. Defaults to None.

        Returns:
            List[ClsDataSample]: A list of data samples which contains the
            predicted results.
        """
        # The part can be traced by torch.fx
        cls_score = self(feats)

        # The part can not be traced by torch.fx
        predictions = self._get_predictions(cls_score, data_samples)
        return predictions

    def _get_predictions(
        self,
        cls_score: torch.Tensor,
        data_samples: Optional[List[ClsDataSample]]
    ) -> List[ClsDataSample]:
        """Post-process the output of head.

        Applies softmax over dim 1 and takes the argmax as the predicted
        label. If ``data_samples`` is given, its items are updated in place
        (``zip`` stops at the shorter of samples and scores) and the same
        list is returned; otherwise a new ``ClsDataSample`` is created for
        every row of ``cls_score``.
        """
        pred_scores = F.softmax(cls_score, dim=1)
        # keepdim=True keeps one label per row as a length-1 tensor.
        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

        if data_samples is not None:
            for data_sample, score, label in zip(data_samples, pred_scores,
                                                 pred_labels):
                data_sample.set_pred_score(score).set_pred_label(label)
        else:
            data_samples = []
            for score, label in zip(pred_scores, pred_labels):
                data_samples.append(ClsDataSample().set_pred_score(
                    score).set_pred_label(label))

        return data_samples