semantic_segmentation_learner.py — 129 lines (104 loc), 3.91 KB
from typing import TYPE_CHECKING, Optional, Tuple
import warnings
import logging
import torch
from torch.nn import functional as F
from rastervision.pytorch_learner.learner import Learner
from rastervision.pytorch_learner.utils import (
compute_conf_mat_metrics, compute_conf_mat, aggregate_metrics)
from rastervision.pytorch_learner.dataset.visualizer import (
SemanticSegmentationVisualizer)
if TYPE_CHECKING:
from torch import nn
warnings.filterwarnings('ignore')
log = logging.getLogger(__name__)
class SemanticSegmentationLearner(Learner):
    """Learner that trains and evaluates semantic segmentation models.

    Expects batches of ``(x, y)`` where ``x`` is a batch of images and ``y``
    is a batch of integer label maps. Validation accumulates a confusion
    matrix across batches and derives per-class metrics from it at the end.
    """

    def get_visualizer_class(self):
        """Return the visualizer class used for plotting batches/predictions."""
        return SemanticSegmentationVisualizer

    def train_step(self, batch, batch_ind):
        """Compute the training loss for one batch.

        Args:
            batch: ``(x, y)`` tuple of inputs and ground-truth label maps.
            batch_ind: Index of the batch within the epoch (unused here).

        Returns:
            Dict with key ``'train_loss'``.
        """
        x, y = batch
        out = self.post_forward(self.model(x))
        return {'train_loss': self.loss(out, y)}

    def validate_step(self, batch, batch_ind):
        """Compute validation loss and a confusion matrix for one batch.

        Args:
            batch: ``(x, y)`` tuple of inputs and ground-truth label maps.
            batch_ind: Index of the batch within the epoch (unused here).

        Returns:
            Dict with keys ``'val_loss'`` and ``'conf_mat'``.
        """
        x, y = batch
        out = self.post_forward(self.model(x))
        val_loss = self.loss(out, y)

        num_labels = len(self.cfg.data.class_names)
        # Flatten both predictions and labels to 1D so each pixel counts as
        # one sample in the confusion matrix.
        y = y.view(-1)
        out = self.prob_to_pred(out).view(-1)
        conf_mat = compute_conf_mat(out, y, num_labels)

        return {'val_loss': val_loss, 'conf_mat': conf_mat}

    def validate_end(self, outputs):
        """Aggregate per-batch validation outputs into epoch metrics.

        Args:
            outputs: List of dicts as returned by :meth:`validate_step`.

        Returns:
            Dict of scalar metrics, including per-class metrics derived from
            the summed confusion matrix.
        """
        metrics = aggregate_metrics(outputs, exclude_keys={'conf_mat'})
        # Element-wise tensor sum across batches; a generator avoids
        # materializing an intermediate list.
        conf_mat = sum(o['conf_mat'] for o in outputs)
        conf_mat_metrics = compute_conf_mat_metrics(conf_mat,
                                                    self.cfg.data.class_names)
        metrics.update(conf_mat_metrics)
        return metrics

    def post_forward(self, x):
        """Normalize raw model output to a tensor.

        torchvision segmentation models return a dict with the main head's
        output under the ``'out'`` key; plain models return a tensor.
        """
        if isinstance(x, dict):
            return x['out']
        return x

    def predict(self,
                x: torch.Tensor,
                raw_out: bool = False,
                out_shape: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Run inference on a (batch of) image(s).

        Args:
            x: Input image(s); batched automatically if needed.
            raw_out: If ``True``, return per-class probabilities instead of
                hard class predictions.
            out_shape: ``(height, width)`` to resize the output to. Defaults
                to the spatial shape of ``x``.

        Returns:
            Predictions on CPU: class indices, or probabilities if
            ``raw_out=True``.
        """
        if out_shape is None:
            out_shape = x.shape[-2:]

        x = self.to_batch(x).float()
        x = self.to_device(x, self.device)
        with torch.inference_mode():
            out = self.model(x)
        out = self.post_forward(out)
        out = self.postprocess_model_output(
            out, raw_out=raw_out, out_shape=out_shape)
        return out

    def predict_onnx(
            self,
            x: torch.Tensor,
            raw_out: bool = False,
            out_shape: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Like :meth:`predict`, but for an ONNX-runtime-backed model.

        No device transfer or ``inference_mode`` context is needed since the
        model does not run through PyTorch autograd.
        """
        if out_shape is None:
            out_shape = x.shape[-2:]

        x = self.to_batch(x).float()
        out = self.model(x)
        out = self.post_forward(out)
        out = self.postprocess_model_output(
            out, raw_out=raw_out, out_shape=out_shape)
        return out

    def postprocess_model_output(self, out: torch.Tensor, raw_out: bool,
                                 out_shape: Tuple[int, int]):
        """Convert raw logits to probabilities (or hard labels) at a target size.

        Args:
            out: Raw model logits of shape ``(N, C, H, W)``.
            raw_out: If ``True``, keep per-class probabilities; otherwise take
                the argmax over classes.
            out_shape: Target spatial ``(height, width)``.

        Returns:
            Post-processed output moved to CPU.
        """
        out = out.softmax(dim=1)
        # Resize to the requested shape if the model's output stride changed it.
        if out.shape[-2:] != out_shape:
            out = F.interpolate(
                out, size=out_shape, mode='bilinear', align_corners=False)
        if not raw_out:
            out = self.prob_to_pred(out)
        out = self.to_device(out, 'cpu')
        return out

    def prob_to_pred(self, x):
        """Collapse class probabilities/logits to class indices via argmax."""
        return x.argmax(1)

    def export_to_onnx(self,
                       path: str,
                       model: Optional['nn.Module'] = None,
                       sample_input: Optional[torch.Tensor] = None,
                       **kwargs) -> None:
        """Export the model to ONNX with dynamic batch/height/width axes.

        Defaults suitable for segmentation (variable batch size and spatial
        dims) are applied, but any of them can be overridden via ``kwargs``.
        """
        args = dict(
            input_names=['x'],
            output_names=['out'],
            dynamic_axes={
                'x': {
                    0: 'batch_size',
                    2: 'height',
                    3: 'width',
                },
                'out': {
                    0: 'batch_size',
                    2: 'height',
                    3: 'width',
                },
            },
        )
        # Caller-supplied kwargs take precedence over the defaults above.
        args.update(kwargs)
        return super().export_to_onnx(path, model, sample_input, **args)