-
Notifications
You must be signed in to change notification settings - Fork 298
/
segmentation.py
316 lines (278 loc) · 11.9 KB
/
segmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
"""Trainers for semantic segmentation."""
import os
from typing import Any
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import torch.nn as nn
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MetricCollection
from torchmetrics.classification import MulticlassAccuracy, MulticlassJaccardIndex
from torchvision.models._api import WeightsEnum
from ..datasets import RGBBandsMissingError, unbind_samples
from ..models import FCN, get_weight
from . import utils
from .base import BaseTask
class SemanticSegmentationTask(BaseTask):
    """Semantic Segmentation."""

    def __init__(
        self,
        model: str = 'unet',
        backbone: str = 'resnet50',
        weights: WeightsEnum | str | bool | None = None,
        in_channels: int = 3,
        num_classes: int = 1000,
        num_filters: int = 3,
        loss: str = 'ce',
        class_weights: Tensor | None = None,
        ignore_index: int | None = None,
        lr: float = 1e-3,
        patience: int = 10,
        freeze_backbone: bool = False,
        freeze_decoder: bool = False,
    ) -> None:
        """Initialize a new SemanticSegmentationTask instance.

        Args:
            model: Name of the model to use: 'unet' or 'deeplabv3+'
                (`smp <https://smp.readthedocs.io/en/latest/models.html>`__
                models) or 'fcn'.
            backbone: Name of the `timm
                <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp
                <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False or
                None for random weights, or the path to a saved model state dict. FCN
                model does not support pretrained weights. Pretrained ViT weight enums
                are not supported yet.
            in_channels: Number of input channels to model.
            num_classes: Number of prediction classes (including the background).
            num_filters: Number of filters. Only applicable when model='fcn'.
            loss: Name of the loss function, currently supports
                'ce', 'jaccard' or 'focal' loss.
            class_weights: Optional rescaling weight given to each
                class and used with 'ce' loss.
            ignore_index: Optional integer class index to ignore in the loss and
                metrics.
            lr: Learning rate for optimizer.
            patience: Patience for learning rate scheduler.
            freeze_backbone: Freeze the backbone network to fine-tune the
                decoder and segmentation head. Ignored when model='fcn'.
            freeze_decoder: Freeze the decoder network to linear probe
                the segmentation head. Ignored when model='fcn'.

        .. versionchanged:: 0.3
           *ignore_zeros* was renamed to *ignore_index*.

        .. versionchanged:: 0.4
           *segmentation_model*, *encoder_name*, and *encoder_weights*
           were renamed to *model*, *backbone*, and *weights*.

        .. versionadded:: 0.5
           The *class_weights*, *freeze_backbone*, and *freeze_decoder* parameters.

        .. versionchanged:: 0.5
           The *weights* parameter now supports WeightEnums and checkpoint paths.
           *learning_rate* and *learning_rate_schedule_patience* were renamed to
           *lr* and *patience*.

        .. versionchanged:: 0.6
           The *ignore_index* parameter now works for jaccard loss.
        """
        # *weights* is kept as a plain attribute (configure_models reads it from
        # here, not from hparams) and passed as ``ignore`` to BaseTask --
        # presumably so it is excluded from the saved hyperparameters. It is set
        # before BaseTask.__init__ in case that constructor calls
        # configure_models (assumption; confirm in BaseTask).
        self.weights = weights
        super().__init__(ignore='weights')

    def configure_models(self) -> None:
        """Initialize the model.

        Builds ``self.model`` from the *model*, *backbone*, *in_channels*,
        *num_classes* and *num_filters* hyperparameters, optionally loads
        pretrained encoder weights, and freezes the requested sub-networks.

        Raises:
            ValueError: If *model* is invalid.
        """
        model: str = self.hparams['model']
        backbone: str = self.hparams['backbone']
        weights = self.weights
        in_channels: int = self.hparams['in_channels']
        num_classes: int = self.hparams['num_classes']
        num_filters: int = self.hparams['num_filters']

        # ``weights=True`` asks smp for its ImageNet encoder weights; every other
        # value builds a randomly initialized encoder (explicit weights, if any,
        # are loaded into it below).
        encoder_weights = 'imagenet' if weights is True else None

        if model == 'unet':
            self.model = smp.Unet(
                encoder_name=backbone,
                encoder_weights=encoder_weights,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == 'deeplabv3+':
            self.model = smp.DeepLabV3Plus(
                encoder_name=backbone,
                encoder_weights=encoder_weights,
                in_channels=in_channels,
                classes=num_classes,
            )
        elif model == 'fcn':
            self.model = FCN(
                in_channels=in_channels, classes=num_classes, num_filters=num_filters
            )
        else:
            raise ValueError(
                f"Model type '{model}' is not valid. "
                "Currently, only supports 'unet', 'deeplabv3+' and 'fcn'."
            )

        # Pretrained weights and freezing are applied only to the smp models,
        # through their ``encoder``/``decoder`` sub-modules; 'fcn' skips both.
        uses_smp = model in ('unet', 'deeplabv3+')

        # Load explicit pretrained encoder weights (``True`` was already handled
        # by ``encoder_weights`` above).
        if uses_smp and weights and weights is not True:
            if isinstance(weights, WeightsEnum):
                state_dict = weights.get_state_dict(progress=True)
            elif os.path.exists(weights):
                # A local checkpoint: keep only the backbone's state dict.
                _, state_dict = utils.extract_backbone(weights)
            else:
                # Otherwise treat the string as the name of a weight enum.
                state_dict = get_weight(weights).get_state_dict(progress=True)
            self.model.encoder.load_state_dict(state_dict)

        # Freeze backbone
        if uses_smp and self.hparams['freeze_backbone']:
            for param in self.model.encoder.parameters():
                param.requires_grad = False

        # Freeze decoder
        if uses_smp and self.hparams['freeze_decoder']:
            for param in self.model.decoder.parameters():
                param.requires_grad = False

    def configure_losses(self) -> None:
        """Initialize the loss criterion.

        Sets ``self.criterion`` according to the *loss* hyperparameter, honoring
        *ignore_index* (and *class_weights* for 'ce').

        Raises:
            ValueError: If *loss* is invalid.
        """
        loss: str = self.hparams['loss']
        ignore_index = self.hparams['ignore_index']
        if loss == 'ce':
            # CrossEntropyLoss needs an integer ignore_index even when nothing
            # should be ignored; -1000 acts as a sentinel that presumably never
            # occurs as a real class label.
            ignore_value = -1000 if ignore_index is None else ignore_index
            self.criterion = nn.CrossEntropyLoss(
                ignore_index=ignore_value, weight=self.hparams['class_weights']
            )
        elif loss == 'jaccard':
            # JaccardLoss requires a list of classes to use instead of a class
            # index to ignore.
            classes = [
                i for i in range(self.hparams['num_classes']) if i != ignore_index
            ]
            self.criterion = smp.losses.JaccardLoss(mode='multiclass', classes=classes)
        elif loss == 'focal':
            self.criterion = smp.losses.FocalLoss(
                'multiclass', ignore_index=ignore_index, normalized=True
            )
        else:
            raise ValueError(
                f"Loss type '{loss}' is not valid. "
                "Currently, supports 'ce', 'jaccard' or 'focal' loss."
            )

    def configure_metrics(self) -> None:
        """Initialize the performance metrics.

        * :class:`~torchmetrics.classification.MulticlassAccuracy`: Overall accuracy
          (OA) using 'micro' averaging. The number of true positives divided by the
          dataset size. Higher values are better.
        * :class:`~torchmetrics.classification.MulticlassJaccardIndex`: Intersection
          over union (IoU). Uses 'micro' averaging. Higher values are better.

        Both metrics receive the *ignore_index* hyperparameter. Independent
        copies are created for the train, validation and test stages, with
        ``train_``, ``val_`` and ``test_`` name prefixes respectively.

        .. note::
           * 'Micro' averaging suits overall performance evaluation but may not reflect
             minority class accuracy.
           * 'Macro' averaging, not used here, gives equal weight to each class, useful
             for balanced performance assessment across imbalanced classes.
        """
        num_classes: int = self.hparams['num_classes']
        ignore_index: int | None = self.hparams['ignore_index']
        metrics = MetricCollection(
            [
                MulticlassAccuracy(
                    num_classes=num_classes,
                    ignore_index=ignore_index,
                    multidim_average='global',
                    average='micro',
                ),
                MulticlassJaccardIndex(
                    num_classes=num_classes, ignore_index=ignore_index, average='micro'
                ),
            ]
        )
        # Separate clones so each stage accumulates its own metric state.
        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader; ``batch['image']`` and
                ``batch['mask']`` are read.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
        """
        x = batch['image']
        y = batch['mask']
        batch_size = x.shape[0]
        y_hat = self(x)
        loss: Tensor = self.criterion(y_hat, y)
        self.log('train_loss', loss, batch_size=batch_size)
        self.train_metrics(y_hat, y)
        self.log_dict(self.train_metrics, batch_size=batch_size)
        return loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """Compute the validation loss and additional metrics.

        For the first 10 batches, when the datamodule can plot samples and the
        logger exposes ``experiment.add_figure``, a figure of the first sample
        (image, mask and prediction) is also logged.

        Args:
            batch: The output of your DataLoader; ``batch['image']`` and
                ``batch['mask']`` are read. When a figure is logged, ``batch`` is
                modified in place: 'image' and 'mask' are moved to the CPU and a
                'prediction' entry is added.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch['image']
        y = batch['mask']
        batch_size = x.shape[0]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss, batch_size=batch_size)
        self.val_metrics(y_hat, y)
        self.log_dict(self.val_metrics, batch_size=batch_size)

        if (
            batch_idx < 10
            and hasattr(self.trainer, 'datamodule')
            and hasattr(self.trainer.datamodule, 'plot')
            and self.logger
            and hasattr(self.logger, 'experiment')
            and hasattr(self.logger.experiment, 'add_figure')
        ):
            datamodule = self.trainer.datamodule
            batch['prediction'] = y_hat.argmax(dim=1)
            for key in ['image', 'mask', 'prediction']:
                batch[key] = batch[key].cpu()
            sample = unbind_samples(batch)[0]

            fig: Figure | None = None
            try:
                fig = datamodule.plot(sample)
            except RGBBandsMissingError:
                # The dataset cannot render an RGB view; skip plotting.
                pass

            if fig is not None:
                summary_writer = self.logger.experiment
                try:
                    summary_writer.add_figure(
                        f'image/{batch_idx}', fig, global_step=self.global_step
                    )
                finally:
                    # Close this specific figure (not just pyplot's "current"
                    # one) so figures do not accumulate across batches.
                    plt.close(fig)

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """Compute the test loss and additional metrics.

        Args:
            batch: The output of your DataLoader; ``batch['image']`` and
                ``batch['mask']`` are read.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
        x = batch['image']
        y = batch['mask']
        batch_size = x.shape[0]
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('test_loss', loss, batch_size=batch_size)
        self.test_metrics(y_hat, y)
        self.log_dict(self.test_metrics, batch_size=batch_size)

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the predicted class probabilities.

        Args:
            batch: The output of your DataLoader; only ``batch['image']`` is read.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            Output predicted probabilities: the model output softmaxed over
            dim 1, the same axis ``validation_step`` takes the argmax over.
        """
        x = batch['image']
        y_hat: Tensor = self(x).softmax(dim=1)
        return y_hat