-
Notifications
You must be signed in to change notification settings - Fork 1k
/
base_edit_model.py
185 lines (145 loc) · 6.73 KB
/
base_edit_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
from mmengine.model import BaseModel
from mmedit.registry import MODELS
from mmedit.structures import EditDataSample, PixelData
@MODELS.register_module()
class BaseEditModel(BaseModel):
    """Base model for image and video editing.

    It must contain a generator that takes frames as inputs and outputs an
    interpolated frame. It also has a pixel-wise loss for training.

    Args:
        generator (dict): Config for the generator structure.
        pixel_loss (dict): Config for pixel-wise loss.
        train_cfg (dict): Config for training. Default: None.
        test_cfg (dict): Config for testing. Default: None.
        init_cfg (dict, optional): The weight initialized config for
            :class:`BaseModule`.
        data_preprocessor (dict, optional): The pre-process config of
            :class:`BaseDataPreprocessor`.

    Attributes:
        init_cfg (dict, optional): Initialization config dict.
        data_preprocessor (:obj:`BaseDataPreprocessor`): Used for
            pre-processing data sampled by dataloader to the format accepted by
            :meth:`forward`. Default: None.
    """

    def __init__(self,
                 generator,
                 pixel_loss,
                 train_cfg=None,
                 test_cfg=None,
                 init_cfg=None,
                 data_preprocessor=None):
        super().__init__(
            init_cfg=init_cfg, data_preprocessor=data_preprocessor)

        self.train_cfg = train_cfg
        self.test_cfg = test_cfg

        # Generator network, built lazily from its registry config.
        self.generator = MODELS.build(generator)

        # Pixel-wise training loss, built from its registry config.
        self.pixel_loss = MODELS.build(pixel_loss)

    def forward(self,
                inputs: torch.Tensor,
                data_samples: Optional[List[EditDataSample]] = None,
                mode: str = 'tensor',
                **kwargs):
        """Returns losses or predictions of training, validation, testing, and
        simple inference process.

        ``forward`` method of BaseModel is an abstract method, its subclasses
        must implement this method.

        Accepts ``inputs`` and ``data_samples`` processed by
        :attr:`data_preprocessor`, and returns results according to mode
        arguments.

        During non-distributed training, validation, and testing process,
        ``forward`` will be called by ``BaseModel.train_step``,
        ``BaseModel.val_step`` and ``BaseModel.test_step`` directly.

        During distributed data parallel training process,
        ``MMSeparateDistributedDataParallel.train_step`` will first call
        ``DistributedDataParallel.forward`` to enable automatic
        gradient synchronization, and then call ``forward`` to get training
        loss.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional):
                data samples collated by :attr:`data_preprocessor`.
            mode (str): mode should be one of ``loss``, ``predict`` and
                ``tensor``. Default: 'tensor'.

                - ``loss``: Called by ``train_step`` and return loss ``dict``
                  used for logging
                - ``predict``: Called by ``val_step`` and ``test_step``
                  and return list of ``BaseDataElement`` results used for
                  computing metric.
                - ``tensor``: Called by custom use to get ``Tensor`` type
                  results.

        Returns:
            ForwardResults:

                - If ``mode == loss``, return a ``dict`` of loss tensor used
                  for backward and logging.
                - If ``mode == predict``, return a ``list`` of
                  :obj:`BaseDataElement` for computing metric
                  and getting inference result.
                - If ``mode == tensor``, return a tensor or ``tuple`` of tensor
                  or ``dict`` or tensor for custom use.

        Raises:
            ValueError: If ``mode`` is none of ``'tensor'``, ``'predict'``
                or ``'loss'``.
        """
        if mode == 'tensor':
            return self.forward_tensor(inputs, data_samples, **kwargs)

        elif mode == 'predict':
            predictions = self.forward_inference(inputs, data_samples,
                                                 **kwargs)
            predictions = self.convert_to_datasample(data_samples, predictions)
            return predictions

        elif mode == 'loss':
            return self.forward_train(inputs, data_samples, **kwargs)

        # Fail loudly on an unsupported mode instead of silently
        # returning None, which would surface as an obscure error in the
        # caller (e.g. ``train_step`` iterating over a None loss dict).
        raise ValueError(
            f'Invalid mode "{mode}". '
            'Only "tensor", "predict" and "loss" are supported.')

    def convert_to_datasample(self, inputs, data_samples):
        """Attach predictions to their corresponding input data samples.

        Note: despite the parameter names, this is called as
        ``convert_to_datasample(data_samples, predictions)`` in
        :meth:`forward`, i.e. ``inputs`` are the collated data samples and
        ``data_samples`` are the prediction samples.

        Args:
            inputs (List[EditDataSample]): Batch data samples.
            data_samples (List[EditDataSample]): Batch predictions, one per
                data sample.

        Returns:
            List[EditDataSample]: ``inputs`` with each element's ``output``
            attribute set to the matching prediction.
        """
        for data_sample, output in zip(inputs, data_samples):
            data_sample.output = output
        return inputs

    def forward_tensor(self, inputs, data_samples=None, **kwargs):
        """Forward tensor. Returns result of simple forward.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional):
                data samples collated by :attr:`data_preprocessor`.

        Returns:
            Tensor: result of simple forward.
        """
        feats = self.generator(inputs, **kwargs)
        return feats

    def forward_inference(self, inputs, data_samples=None, **kwargs):
        """Forward inference. Returns predictions of validation, testing, and
        simple inference.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional):
                data samples collated by :attr:`data_preprocessor`.

        Returns:
            List[EditDataSample]: predictions.
        """
        feats = self.forward_tensor(inputs, data_samples, **kwargs)
        # Undo the preprocessor's normalization/padding before packaging
        # results. NOTE(review): later mmengine versions rename this method
        # to ``destruct`` — confirm against the installed data preprocessor.
        feats = self.data_preprocessor.destructor(feats)

        # One EditDataSample per batch element, moved to CPU and carrying
        # over the metainfo of the corresponding input sample.
        predictions = []
        for idx in range(feats.shape[0]):
            predictions.append(
                EditDataSample(
                    pred_img=PixelData(data=feats[idx].to('cpu')),
                    metainfo=data_samples[idx].metainfo))

        return predictions

    def forward_train(self, inputs, data_samples=None, **kwargs):
        """Forward training. Returns dict of losses of training.

        Args:
            inputs (torch.Tensor): batch input tensor collated by
                :attr:`data_preprocessor`.
            data_samples (List[BaseDataElement], optional):
                data samples collated by :attr:`data_preprocessor`.

        Returns:
            dict: Dict of losses.
        """
        feats = self.forward_tensor(inputs, data_samples, **kwargs)
        # Stack the per-sample ground-truth images into a single batch
        # tensor so the pixel loss can compare against ``feats`` directly.
        gt_imgs = [data_sample.gt_img.data for data_sample in data_samples]
        batch_gt_data = torch.stack(gt_imgs)

        loss = self.pixel_loss(feats, batch_gt_data)

        return dict(loss=loss)