-
Notifications
You must be signed in to change notification settings - Fork 25.3k
/
modeling_oneformer.py
3251 lines (2778 loc) 路 140 KB
/
modeling_oneformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2022 SHI Labs and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch OneFormer model."""
import copy
import math
import warnings
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch import Tensor, nn
from torch.cuda.amp import autocast
from ... import AutoBackbone
from ...activations import ACT2FN
from ...modeling_outputs import BaseModelOutput
from ...modeling_utils import PreTrainedModel
from ...utils import (
ModelOutput,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_scipy_available,
logging,
replace_return_docstrings,
requires_backends,
)
from .configuration_oneformer import OneFormerConfig
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)

# Documentation constants. Presumably consumed by the docstring decorators imported above
# (their use is not visible in this part of the file).
_CONFIG_FOR_DOC = "OneFormerConfig"
_CHECKPOINT_FOR_DOC = "shi-labs/oneformer_ade20k_swin_tiny"

# Identifiers of known pretrained OneFormer checkpoints.
ONEFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "shi-labs/oneformer_ade20k_swin_tiny",
    # See all OneFormer models at https://huggingface.co/models?filter=oneformer
]

# scipy is optional: `linear_sum_assignment` is only bound when scipy is installed.
# `OneFormerHungarianMatcher.forward` calls it, and `OneFormerLoss.__init__` checks for the backend.
if is_scipy_available():
    from scipy.optimize import linear_sum_assignment
def _get_clones(module, N):
    """Return an `nn.ModuleList` holding `N` independent deep copies of `module`."""
    clones = []
    for _ in range(N):
        # deepcopy so that no parameters are shared between the copies
        clones.append(copy.deepcopy(module))
    return nn.ModuleList(clones)
# Copied from transformers.models.deformable_detr.modeling_deformable_detr.multi_scale_deformable_attention
def multi_scale_deformable_attention(
    value: Tensor, value_spatial_shapes: Tensor, sampling_locations: Tensor, attention_weights: Tensor
) -> Tensor:
    """
    Sample `value` at `sampling_locations` on every feature level and combine the samples with `attention_weights`.

    Args:
        value (`torch.Tensor`):
            Tensor of shape `(batch_size, sum of height * width over levels, num_heads, hidden_dim)`.
        value_spatial_shapes (`torch.Tensor`):
            Tensor of shape `(num_levels, 2)` holding the `(height, width)` of each level.
        sampling_locations (`torch.Tensor`):
            Tensor of shape `(batch_size, num_queries, num_heads, num_levels, num_points, 2)`. The coordinates are
            rescaled with `2 * x - 1` before being passed to `grid_sample`.
        attention_weights (`torch.Tensor`):
            Tensor of shape `(batch_size, num_queries, num_heads, num_levels, num_points)`.

    Returns:
        `torch.Tensor`: Tensor of shape `(batch_size, num_queries, num_heads * hidden_dim)`.
    """
    batch_size, _, num_heads, hidden_dim = value.shape
    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape

    # split the flattened value sequence back into one chunk per feature level
    tokens_per_level = [height.item() * width.item() for height, width in value_spatial_shapes]
    level_values = value.split(tokens_per_level, dim=1)

    # rescale the coordinates to the [-1, 1] convention used by grid_sample
    grids = 2 * sampling_locations - 1

    sampled_per_level = []
    for level, (height, width) in enumerate(value_spatial_shapes):
        # (batch_size, height*width, num_heads, hidden_dim)
        # -> (batch_size, num_heads*hidden_dim, height*width)
        # -> (batch_size*num_heads, hidden_dim, height, width)
        feature_map = level_values[level].flatten(2).transpose(1, 2)
        feature_map = feature_map.reshape(batch_size * num_heads, hidden_dim, height, width)

        # (batch_size, num_queries, num_heads, num_points, 2)
        # -> (batch_size*num_heads, num_queries, num_points, 2)
        level_grid = grids[:, :, :, level].transpose(1, 2).flatten(0, 1)

        # (batch_size*num_heads, hidden_dim, num_queries, num_points)
        sampled = nn.functional.grid_sample(
            feature_map, level_grid, mode="bilinear", padding_mode="zeros", align_corners=False
        )
        sampled_per_level.append(sampled)

    # (batch_size, num_queries, num_heads, num_levels, num_points)
    # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
    weights = attention_weights.transpose(1, 2).reshape(
        batch_size * num_heads, 1, num_queries, num_levels * num_points
    )

    stacked = torch.stack(sampled_per_level, dim=-2).flatten(-2)
    attended = (stacked * weights).sum(-1)
    output = attended.view(batch_size, num_heads * hidden_dim, num_queries)
    return output.transpose(1, 2).contiguous()
# Copied from transformers.models.maskformer.modeling_maskformer.dice_loss
def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
    r"""
    Compute the DICE loss between mask logits and binary labels, with +1 smoothing on both terms of the ratio:

    $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \sum x \cdot y + 1}{\sum x + \sum y + 1} $$

    where `x` is `sigmoid(inputs)` and `y` is `labels`. The per-mask losses are summed and divided by `num_masks`.

    Args:
        inputs (`torch.Tensor`):
            Mask logits. They are passed through a sigmoid and flattened from dimension 1 onwards.
        labels (`torch.Tensor`):
            Binary labels (0 for the negative class and 1 for the positive class), broadcastable against the
            flattened `inputs`.
        num_masks (`int`):
            The number of masks present in the current batch, used for normalization.

    Returns:
        `torch.Tensor`: The computed loss.
    """
    probabilities = inputs.sigmoid().flatten(1)
    intersection = (probabilities * labels).sum(-1)
    cardinality = probabilities.sum(-1) + labels.sum(-1)
    per_mask_loss = 1 - (2 * intersection + 1) / (cardinality + 1)
    return per_mask_loss.sum() / num_masks
# Copied from transformers.models.mask2former.modeling_mask2former.sigmoid_cross_entropy_loss
def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
    r"""
    Binary cross entropy on logits, averaged over dimension 1, summed over masks and divided by `num_masks`.

    Args:
        inputs (`torch.Tensor`):
            A float tensor of logits.
        labels (`torch.Tensor`):
            A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_masks (`int`):
            Normalization factor applied to the summed per-mask losses.

    Returns:
        loss (`torch.Tensor`): The computed loss.
    """
    per_element = nn.functional.binary_cross_entropy_with_logits(inputs, labels, reduction="none")
    per_mask = per_element.mean(1)
    return per_mask.sum() / num_masks
# Copied from transformers.models.maskformer.modeling_maskformer.pair_wise_dice_loss
def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
    """
    Pair-wise variant of the dice loss: every row of `inputs` is scored against every row of `labels`.

    Args:
        inputs (`torch.Tensor`):
            Mask logits, one mask per row once flattened from dimension 1 onwards.
        labels (`torch.Tensor`):
            Binary labels (0 for the negative class and 1 for the positive class), one target per row, with the same
            number of columns as the flattened `inputs`.

    Returns:
        `torch.Tensor`: Matrix of shape `(inputs.shape[0], labels.shape[0])` holding the loss of each pair.
    """
    probabilities = inputs.sigmoid().flatten(1)
    overlap = torch.matmul(probabilities, labels.T)
    # broadcast the per-row sums into a (num_inputs, num_labels) matrix
    totals = probabilities.sum(-1)[:, None] + labels.sum(-1)[None, :]
    return 1 - (2 * overlap + 1) / (totals + 1)
# Copied from transformers.models.mask2former.modeling_mask2former.pair_wise_sigmoid_cross_entropy_loss
def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    r"""
    Pair-wise variant of the sigmoid cross entropy loss: every row of `inputs` is scored against every row of
    `labels`.

    Args:
        inputs (`torch.Tensor`):
            A 2D tensor of mask logits, one mask per row.
        labels (`torch.Tensor`):
            A 2D tensor of binary labels (0 for the negative class and 1 for the positive class), one target per row,
            with the same number of columns as `inputs`.

    Returns:
        loss (`torch.Tensor`): Matrix of shape `(inputs.shape[0], labels.shape[0])` holding the loss of each pair,
        averaged over the columns.
    """
    num_columns = inputs.shape[1]
    bce_with_logits = nn.functional.binary_cross_entropy_with_logits

    # per-element cost of predicting each point as positive / as negative
    cost_if_positive = bce_with_logits(inputs, torch.ones_like(inputs), reduction="none")
    cost_if_negative = bce_with_logits(inputs, torch.zeros_like(inputs), reduction="none")

    # the label rows select which of the two costs applies at every point
    positive_part = torch.matmul(cost_if_positive, labels.T)
    negative_part = torch.matmul(cost_if_negative, (1 - labels).T)
    pairwise = positive_part + negative_part
    return pairwise / num_columns
# Copied from transformers.models.mask2former.modeling_mask2former.sample_point
def sample_point(
    input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
) -> torch.Tensor:
    """
    Thin wrapper around `torch.nn.functional.grid_sample` that also accepts 3D `point_coordinates`.

    Args:
        input_features (`torch.Tensor` of shape `(batch_size, channels, height, width)`):
            Feature map to sample from.
        point_coordinates (`torch.Tensor` of shape `(batch_size, num_points, 2)` or
            `(batch_size, grid_height, grid_width, 2)`):
            Point coordinates normalized to `[0, 1] x [0, 1]`. They are rescaled to `[-1, 1]` before sampling.
        add_dim (`bool`):
            Whether a dimension was added to the coordinates and must be squeezed out of the result. It is forced to
            `True` when `point_coordinates` is 3D.
        **kwargs:
            Forwarded unchanged to `grid_sample`.

    Returns:
        point_features (`torch.Tensor` of shape `(batch_size, channels, num_points)` or
            `(batch_size, channels, grid_height, grid_width)`):
            Features sampled at `point_coordinates`.
    """
    squeeze_result = add_dim
    if point_coordinates.dim() == 3:
        # grid_sample needs a 4D grid: insert a singleton axis and remove it again afterwards
        squeeze_result = True
        point_coordinates = point_coordinates.unsqueeze(2)

    grid = 2.0 * point_coordinates - 1.0
    sampled = torch.nn.functional.grid_sample(input_features, grid, **kwargs)
    return sampled.squeeze(3) if squeeze_result else sampled
# Refactored from https://github.com/SHI-Labs/OneFormer/blob/33ebb56ed34f970a30ae103e786c0cb64c653d9a/oneformer/modeling/matcher.py#L93
class OneFormerHungarianMatcher(nn.Module):
    def __init__(
        self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
    ):
        """This class computes an assignment between the labels and the predictions of the network.

        For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more
        predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
        un-matched (and thus treated as non-objects).

        Params:
            cost_class (float, *optional*, defaults to 1.0):
                This is the relative weight of the classification error in the matching cost.
            cost_mask (float, *optional*, defaults to 1.0):
                This is the relative weight of the sigmoid ce loss of the binary mask in the matching cost.
            cost_dice (float, *optional*, defaults to 1.0):
                This is the relative weight of the dice loss of the binary mask in the matching cost
            num_points (int, *optional*, defaults to 12544):
                Number of points to be sampled for dice and mask loss matching cost.

        Raises:
            `ValueError`: If all three cost weights are 0.
        """
        super().__init__()
        if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
            raise ValueError("All costs cant be 0")
        self.cost_class = cost_class
        self.cost_mask = cost_mask
        self.cost_dice = cost_dice
        self.num_points = num_points

    @torch.no_grad()
    def forward(
        self, masks_queries_logits, class_queries_logits, mask_labels, class_labels
    ) -> List[Tuple[Tensor, Tensor]]:
        """Performs the matching

        Params:
            masks_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
            class_queries_logits (`torch.Tensor`):
                A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
            mask_labels (`torch.Tensor`):
                One entry per batch element, each a tensor of dim `num_target_boxes, height, width` containing the
                target masks.
            class_labels (`torch.Tensor`):
                One entry per batch element, each a tensor of dim `num_target_boxes` (where num_target_boxes is the
                number of ground-truth objects in the target) containing the class labels.

        Returns:
            `List[Tuple[Tensor]]`: A list of size batch_size, containing tuples of (index_i, index_j) where:
                - index_i is the indices of the selected predictions (in order)
                - index_j is the indices of the corresponding selected labels (in order)
            For each batch element, it holds:
                len(index_i) = len(index_j) = min(num_queries, num_targets).
        """
        indices: List[Tuple[np.ndarray, np.ndarray]] = []
        num_queries = class_queries_logits.shape[1]

        # iterate over the batch, one image at a time
        for pred_probs, pred_mask, target_mask, labels in zip(
            class_queries_logits, masks_queries_logits, mask_labels, class_labels
        ):
            pred_probs = pred_probs.softmax(-1)
            # Compute the classification cost. Contrary to the loss, we don't use the NLL,
            # but approximate it in 1 - proba[target class].
            # The 1 is a constant that doesn't change the matching, it can be omitted.
            cost_class = -pred_probs[:, labels]

            pred_mask = pred_mask[:, None]
            target_mask = target_mask[:, None].to(pred_mask.device)

            # all masks share the same set of points for efficient matching!
            point_coords = torch.rand(1, self.num_points, 2, device=pred_mask.device)

            # sample the ground-truth masks and the predicted masks at those points
            target_mask = sample_point(
                target_mask,
                point_coords.repeat(target_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)
            pred_mask = sample_point(
                pred_mask,
                point_coords.repeat(pred_mask.shape[0], 1, 1),
                align_corners=False,
            ).squeeze(1)

            with autocast(enabled=False):
                pred_mask = pred_mask.float()
                target_mask = target_mask.float()

                # compute the sigmoid ce loss
                cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
                # compute the dice loss
                cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
                # final cost matrix
                cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
                # `linear_sum_assignment` raises ValueError on NaN / infinite entries (e.g. after an
                # overflow in the logits), so map them to large finite costs before solving
                cost_matrix = torch.nan_to_num(cost_matrix, nan=0.0, posinf=1e10, neginf=-1e10)
                cost_matrix = cost_matrix.reshape(num_queries, -1).cpu()
                # do the assignment using the hungarian algorithm in scipy
                assigned_indices: Tuple[np.ndarray, np.ndarray] = linear_sum_assignment(cost_matrix)
                indices.append(assigned_indices)

        # It could be stacked in one tensor
        matched_indices = [
            (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
        ]
        return matched_indices
class OneFormerLoss(nn.Module):
    def __init__(
        self,
        num_classes: int,
        matcher: "OneFormerHungarianMatcher",
        weight_dict: Dict[str, float],
        eos_coef: float,
        num_points: int,
        oversample_ratio: float,
        importance_sample_ratio: float,
        contrastive_temperature: Optional[float] = None,
    ):
        """
        This class computes the losses using the class predictions, mask predictions and the contrastive queries.
        Oneformer calculates the classification CE loss on the class predictions. Mask predictions are used for
        calculating the binary CE loss and dice loss. The contrastive queries are used for calculating the contrastive
        loss.

        Args:
            num_classes (`int`):
                The number of classes.
            matcher (`OneFormerHungarianMatcher`):
                A torch module that computes the assignments between the predictions and labels.
            weight_dict (`Dict[str, float]`):
                A dictionary of weights to be applied to the different losses.
            eos_coef (`float`):
                Weight to apply to the null class.
            num_points (`int`):
                Number of points to be sampled for dice and mask loss calculations.
            oversample_ratio (`float`):
                Required for pointwise loss calculation.
            importance_sample_ratio (`float`):
                Required for pointwise loss calculation.
            contrastive_temperature (`float`, *optional*):
                Temperature for scaling the contrastive logits. When `None`, no `logit_scale` parameter is created.
        """
        requires_backends(self, ["scipy"])
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef

        # per-class weights for the cross entropy; the extra last entry is the null class
        empty_weight = torch.ones(self.num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer("empty_weight", empty_weight)

        # pointwise mask loss parameters
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio

        self.contrastive_temperature = contrastive_temperature
        if self.contrastive_temperature is not None:
            self.logit_scale = nn.Parameter(torch.tensor(np.log(1 / contrastive_temperature)))

    def _max_by_axis(self, the_list: List[List[int]]) -> List[int]:
        """Return the element-wise maximum over the sub-lists of `the_list` without modifying them."""
        # copy the first sub-list so that the caller's data is not overwritten
        maxes = list(the_list[0])
        for sublist in the_list[1:]:
            for index, item in enumerate(sublist):
                maxes[index] = max(maxes[index], item)
        return maxes

    def _pad_images_to_max_in_batch(self, tensors: List[Tensor]) -> Tuple[Tensor, Tensor]:
        """
        Zero-pad 3D tensors to a common size and stack them.

        Returns:
            `Tuple[Tensor, Tensor]`: The stacked padded tensors and a boolean mask over the last two dimensions that
            is `True` where padding was added.
        """
        # get the maximum size in the batch
        max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
        batch_size = len(tensors)
        # compute final size
        batch_shape = [batch_size] + max_size
        b, _, h, w = batch_shape
        # get metadata
        dtype = tensors[0].dtype
        device = tensors[0].device
        padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
        padding_masks = torch.ones((b, h, w), dtype=torch.bool, device=device)
        # pad the tensors to the size of the biggest one
        for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
            padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
            padding_mask[: tensor.shape[1], : tensor.shape[2]] = False

        return padded_tensors, padding_masks

    def loss_contrastive(self, contrastive_queries_logits: Tensor, text_queries: Tensor):
        """Compute the query-text contrastive loss.

        Args:
            contrastive_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, hidden_dim`
            text_queries (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, hidden_dim`

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
            - **loss_contrastive** -- The query-text contrastive loss computed using task-guided queries
                                    and text queries derived from input text list.
        """
        image_queries = contrastive_queries_logits.float()

        # flatten each sample to a single vector and L2-normalize it
        image_queries = nn.functional.normalize(image_queries.flatten(1), dim=-1)
        text_queries = nn.functional.normalize(text_queries.flatten(1), dim=-1)

        # the learned scale is capped at 100
        logit_scale = torch.clamp(self.logit_scale.exp(), max=100)

        logits_per_text = torch.matmul(text_queries, image_queries.t()) * logit_scale
        logits_per_img = logits_per_text.t()

        # the matching pair of sample i is sample i, hence the arange targets
        loss_img = nn.functional.cross_entropy(
            logits_per_img, torch.arange(len(logits_per_img), device=logits_per_text.device)
        )
        loss_text = nn.functional.cross_entropy(
            logits_per_text, torch.arange(len(logits_per_text), device=logits_per_text.device)
        )

        loss_contrastive = loss_img + loss_text

        losses = {"loss_contrastive": loss_contrastive}
        return losses

    def loss_labels(
        self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the labels using cross entropy.

        Args:
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
            class_labels (`List[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            indices (`Tuple[np.array])`:
                The indices computed by the Hungarian matcher.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
        """
        pred_logits = class_queries_logits
        batch_size, num_queries, _ = pred_logits.shape
        criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
        idx = self._get_predictions_permutation_indices(indices)

        # labels of the matched targets, concatenated over the batch
        target_classes_o = torch.cat([target[j] for target, (_, j) in zip(class_labels, indices)])
        # shape = (batch_size, num_queries); unmatched queries keep the null class `num_classes`
        target_classes = torch.full(
            (batch_size, num_queries), fill_value=self.num_classes, dtype=torch.int64, device=pred_logits.device
        )
        target_classes[idx] = target_classes_o
        # permute pred_logits (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
        pred_logits_transposed = pred_logits.transpose(1, 2)
        loss_ce = criterion(pred_logits_transposed, target_classes)
        losses = {"loss_cross_entropy": loss_ce}
        return losses

    def loss_masks(
        self, masks_queries_logits: Tensor, mask_labels: List[Tensor], indices: Tuple[np.array], num_masks: int
    ) -> Dict[str, Tensor]:
        """Compute the losses related to the masks using sigmoid cross entropy and dice loss.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            indices (`Tuple[np.array])`:
                The indices computed by the Hungarian matcher.
            num_masks (`int)`:
                The number of masks, used for normalization.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing two keys:
            - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
        """
        src_idx = self._get_predictions_permutation_indices(indices)
        tgt_idx = self._get_targets_permutation_indices(indices)

        # keep only the matched predictions: shape (num_matched, height, width)
        pred_masks = masks_queries_logits[src_idx]

        # pad the targets to a common size, stack them, and keep only the matched ones
        target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
        target_masks = target_masks[tgt_idx]

        # add a channel dimension, as required by `sample_point`
        pred_masks = pred_masks[:, None]
        target_masks = target_masks[:, None]

        # point selection and label sampling do not need gradients
        with torch.no_grad():
            # sample point_coords
            point_coords = self.sample_points_using_uncertainty(
                pred_masks,
                self.calculate_uncertainty,
                self.num_points,
                self.oversample_ratio,
                self.importance_sample_ratio,
            )
            # get ground-truth labels
            point_labels = sample_point(target_masks, point_coords, align_corners=False).squeeze(1)

        point_logits = sample_point(pred_masks, point_coords, align_corners=False).squeeze(1)

        losses = {
            "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
            "loss_dice": dice_loss(point_logits, point_labels, num_masks),
        }
        return losses

    # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.calculate_uncertainty
    def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
        """
        In Mask2Former paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits'
        for the foreground class in `classes`.

        Args:
            logits (`torch.Tensor`):
                A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of
                predicted masks in all images. The values are logits.

        Returns:
            scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
            uncertain locations having the highest uncertainty score.
        """
        # logits close to 0 are the least confident, so negate the magnitude
        uncertainty_scores = -(torch.abs(logits))
        return uncertainty_scores

    # Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerLoss.sample_points_using_uncertainty
    def sample_points_using_uncertainty(
        self,
        logits: torch.Tensor,
        uncertainty_function,
        num_points: int,
        oversample_ratio: int,
        importance_sample_ratio: float,
    ) -> torch.Tensor:
        """
        This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
        uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
        prediction as input.

        Args:
            logits (`torch.Tensor`):
                Logit predictions to sample from; its first dimension is the number of masks.
            uncertainty_function:
                A function that takes logit predictions for P points and returns their uncertainties.
            num_points (`int`):
                The number of points P to sample.
            oversample_ratio (`int`):
                Oversampling parameter.
            importance_sample_ratio (`float`):
                Ratio of points that are sampled via importance sampling.

        Returns:
            point_coordinates (`torch.Tensor`):
                Coordinates for P sampled points.
        """
        num_boxes = logits.shape[0]
        num_points_sampled = int(num_points * oversample_ratio)

        # Get random point coordinates
        point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
        # Get sampled prediction value for the point coordinates
        point_logits = sample_point(logits, point_coordinates, align_corners=False)
        # Calculate the uncertainties based on the sampled prediction values of the points
        point_uncertainties = uncertainty_function(point_logits)

        num_uncertain_points = int(importance_sample_ratio * num_points)
        num_random_points = num_points - num_uncertain_points

        # keep the most uncertain candidates; `shift` turns per-mask indices into indices of the flattened view
        idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
        shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
        idx += shift[:, None]
        point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)

        # fill the remainder with uniformly random points
        if num_random_points > 0:
            point_coordinates = torch.cat(
                [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
                dim=1,
            )
        return point_coordinates

    def _get_predictions_permutation_indices(self, indices):
        """Return `(batch_indices, prediction_indices)` for advanced indexing of the matched predictions."""
        # permute predictions following indices
        batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
        predictions_indices = torch.cat([src for (src, _) in indices])
        return batch_indices, predictions_indices

    def _get_targets_permutation_indices(self, indices):
        """Return `(batch_indices, target_indices)` for advanced indexing of the matched targets."""
        # permute labels following indices
        batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
        target_indices = torch.cat([tgt for (_, tgt) in indices])
        return batch_indices, target_indices

    def forward(
        self,
        masks_queries_logits: Tensor,
        class_queries_logits: Tensor,
        contrastive_queries_logits: Tensor,
        mask_labels: List[Tensor],
        class_labels: List[Tensor],
        text_queries: Tensor,
        auxiliary_predictions: Optional[List[Dict[str, Tensor]]] = None,
        calculate_contrastive_loss: bool = True,
    ) -> Dict[str, Tensor]:
        """
        This performs the loss computation.

        Args:
            masks_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, height, width`
            class_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, num_labels`
            contrastive_queries_logits (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, hidden_dim`
            mask_labels (`torch.Tensor`):
                List of mask labels of shape `(labels, height, width)`.
            class_labels (`List[torch.Tensor]`):
                List of class labels of shape `(labels)`.
            text_queries (`torch.Tensor`):
                A tensor of shape `batch_size, num_queries, hidden_dim`
            auxiliary_predictions (sequence of `Dict[str, torch.Tensor]`, *optional*):
                One dict per intermediate decoder layer, each holding the keys `"masks_queries_logits"` and
                `"class_queries_logits"`.
            calculate_contrastive_loss (`bool`, *optional*, defaults to `True`):
                Whether or not to calculate the contrastive loss.

        Returns:
            `Dict[str, Tensor]`: A dict of `torch.Tensor` containing the keys:
            - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
            - **loss_mask** -- The loss computed using sigmoid ce loss on the predicted and ground truth masks.
            - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth masks.
            - **loss_contrastive** -- The query-text contrastive loss computed using object and text queries (only
              when `calculate_contrastive_loss` is `True`).
            When `auxiliary_predictions` is given, the dictionary also contains the same mask and label losses for
            each auxiliary prediction, with the key suffixed by `_{index}`.
        """
        # retrieve the matching between the outputs of the last layer and the labels
        indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
        # number of target masks in the batch, for normalization purposes
        num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
        # get all the losses
        losses: Dict[str, Tensor] = {
            **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
            **self.loss_labels(class_queries_logits, class_labels, indices),
        }

        if calculate_contrastive_loss:
            losses = {**losses, **self.loss_contrastive(contrastive_queries_logits, text_queries)}

        # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if auxiliary_predictions is not None:
            for idx, aux_outputs in enumerate(auxiliary_predictions):
                masks_queries_logits = aux_outputs["masks_queries_logits"]
                class_queries_logits = aux_outputs["class_queries_logits"]
                # no contrastive loss for auxiliary outputs, so no contrastive / text queries are passed
                loss_dict = self.forward(
                    masks_queries_logits,
                    class_queries_logits,
                    None,
                    mask_labels,
                    class_labels,
                    None,
                    calculate_contrastive_loss=False,
                )
                loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
                losses.update(loss_dict)

        return losses

    def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
        """
        Computes the total number of target masks in the batch, for normalization purposes.

        The result is a float tensor of shape `(1,)` on `device`, clamped to a minimum of 1 so that a batch without
        any target mask does not make the mask losses divide by zero.
        """
        num_masks = sum(len(classes) for classes in class_labels)
        num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)
        return torch.clamp(num_masks_pt, min=1)
@dataclass
class OneFormerTransformerDecoderOutput(BaseModelOutput):
    """
    Base class for outputs of the Transformer decoder. This class extends `BaseModelOutput` with attributes for the
    object queries, contrastive logits, class predictions, mask predictions and auxiliary predictions.

    Args:
        object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Queries representation for the region proposals.
        contrastive_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
            Queries representation for the contrastive loss.
        prediction_masks (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
            Mask predictions from last layer of the transformer decoder.
        prediction_class (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
            Class predictions from last layer of the transformer decoder.
        auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
            Tuple of class and mask predictions from each layer of the transformer decoder.
    """

    object_queries: torch.FloatTensor = None
    contrastive_logits: Optional[torch.FloatTensor] = None
    prediction_masks: torch.FloatTensor = None
    prediction_class: torch.FloatTensor = None
    auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
@dataclass
# Copied from transformers.models.mask2former.modeling_mask2former.Mask2FormerPixelDecoderOutput with Mask2->One
class OneFormerPixelDecoderOutput(ModelOutput):
    """
    OneFormer's pixel decoder module output, practically a Multi-Scale Deformable Attention based decoder. It returns
    the mask features and the multiscale features.

    Args:
        multi_scale_features (`tuple(torch.FloatTensor)`):
            Tuple of multi-scale features of scales [1/8, 1/16, 1/32] and shape `(batch_size, num_channels, height,
            width)` from the Multi-Scale Deformable Attention based Pixel Decoder.
        mask_features (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel Decoder
            Layer.
        attentions (`tuple(torch.FloatTensor)`, *optional*):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Attentions weights from pixel decoder. Returned when `output_attentions=True` is passed
            or when `config.output_attentions=True`
    """

    multi_scale_features: Tuple[torch.FloatTensor] = None
    mask_features: torch.FloatTensor = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerPixelLevelModuleOutput(ModelOutput):
    """
    OneFormer's pixel level module output. It returns both the last and (optionally) the hidden states from the
    `encoder` and `decoder`. By default, the `encoder` is a Swin/Dinat Backbone and the `decoder` is a Multi-Scale
    Deformable Attention based decoder.

    Args:
        encoder_features (List of `(torch.FloatTensor)`):
            List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
        decoder_features (List of `(torch.FloatTensor)`):
            List of `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`. Hidden-states (also
            called feature maps) of the model at the output of each stage.
        decoder_last_feature (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
            1/4 scale features from the last Pixel Decoder Layer.
    """

    # Every field defaults to `None`, hence the `Optional[...]` annotations.
    encoder_features: Optional[List[torch.FloatTensor]] = None
    decoder_features: Optional[List[torch.FloatTensor]] = None
    decoder_last_feature: Optional[torch.FloatTensor] = None
@dataclass
class OneFormerModelOutput(ModelOutput):
    """
    Class for outputs of [`OneFormerModel`]. This class returns all the needed hidden states to compute the logits.

    Args:
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
            model at the output of each stage.
        pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
            decoder model at the output of each stage.
        transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
            transformer decoder at the output of each stage.
        transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Output object queries from the last layer in the transformer decoder.
        transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Contrastive queries from the transformer decoder.
        transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
            Mask Predictions from the last layer in the transformer decoder.
        transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
            Class Predictions from the last layer in the transformer decoder.
        transformer_decoder_auxiliary_predictions (Tuple of Dict of `str, torch.FloatTensor`, *optional*):
            Tuple of class and mask predictions from each layer of the transformer decoder.
        text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
            Text queries derived from the input text list used for calculating contrastive loss during training.
        task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
            1D task token to condition the queries.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Self and Cross Attentions weights from transformer decoder.
    """

    # Every field defaults to `None`, hence the `Optional[...]` annotations.
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pixel_decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
    transformer_decoder_object_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_mask_predictions: Optional[torch.FloatTensor] = None
    transformer_decoder_class_predictions: Optional[torch.FloatTensor] = None
    transformer_decoder_auxiliary_predictions: Optional[Tuple[Dict[str, torch.FloatTensor]]] = None
    text_queries: Optional[torch.FloatTensor] = None
    task_token: Optional[torch.FloatTensor] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
@dataclass
class OneFormerForUniversalSegmentationOutput(ModelOutput):
    """
    Class for outputs of [`OneFormerForUniversalSegmentation`].

    This output can be directly passed to [`~OneFormerImageProcessor.post_process_semantic_segmentation`] or
    [`~OneFormerImageProcessor.post_process_instance_segmentation`] or
    [`~OneFormerImageProcessor.post_process_panoptic_segmentation`] depending on the task. Please, see
    [`~OneFormerImageProcessor`] for details regarding usage.

    Args:
        loss (`torch.Tensor`, *optional*):
            The computed loss, returned when labels are present.
        class_queries_logits (`torch.FloatTensor`):
            A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
            query. Note the `+ 1` is needed because we incorporate the null class.
        masks_queries_logits (`torch.FloatTensor`):
            A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
            query.
        auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
            List of class and mask predictions from each layer of the transformer decoder.
        encoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the encoder
            model at the output of each stage.
        pixel_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, num_channels, height, width)`. Hidden-states (also called feature maps) of the pixel
            decoder model at the output of each stage.
        transformer_decoder_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
            shape `(batch_size, sequence_length, hidden_size)`. Hidden-states (also called feature maps) of the
            transformer decoder at the output of each stage.
        transformer_decoder_object_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Output object queries from the last layer in the transformer decoder.
        transformer_decoder_contrastive_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`):
            Contrastive queries from the transformer decoder.
        transformer_decoder_mask_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, height, width)`):
            Mask Predictions from the last layer in the transformer decoder.
        transformer_decoder_class_predictions (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes+1)`):
            Class Predictions from the last layer in the transformer decoder.
        transformer_decoder_auxiliary_predictions (List of Dict of `str, torch.FloatTensor`, *optional*):
            List of class and mask predictions from each layer of the transformer decoder.
        text_queries (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_dim)`, *optional*):
            Text queries derived from the input text list used for calculating contrastive loss during training.
        task_token (`torch.FloatTensor` of shape `(batch_size, hidden_dim)`):
            1D task token to condition the queries.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`. Self and Cross Attentions weights from transformer decoder.
    """

    # Every field defaults to `None`, hence the `Optional[...]` annotations.
    loss: Optional[torch.FloatTensor] = None
    class_queries_logits: Optional[torch.FloatTensor] = None
    masks_queries_logits: Optional[torch.FloatTensor] = None
    auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    pixel_decoder_hidden_states: Optional[List[torch.FloatTensor]] = None
    transformer_decoder_hidden_states: Optional[torch.FloatTensor] = None
    transformer_decoder_object_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_contrastive_queries: Optional[torch.FloatTensor] = None
    transformer_decoder_mask_predictions: Optional[torch.FloatTensor] = None
    transformer_decoder_class_predictions: Optional[torch.FloatTensor] = None
    transformer_decoder_auxiliary_predictions: Optional[List[Dict[str, torch.FloatTensor]]] = None
    text_queries: Optional[torch.FloatTensor] = None
    task_token: Optional[torch.FloatTensor] = None
    attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
# Modified from transformers.models.deformable_detr.modeling_deformable_detr.DeformableDetrFrozenBatchNorm2d with DeformableDetr->OneFormerPixelDecoder
class OneFormerPixelDecoderFrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
    torchvision.models.resnet[18,34,50,101] produce nans.

    Args:
        n (`int`):
            Number of channels; every buffer is a 1D tensor of this length.
        eps (`float`, *optional*, defaults to 1e-5):
            Value added to the running variance before taking the reciprocal square root.
    """

    def __init__(self, n, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Registered as buffers (not parameters): they are saved in the state dict but never trained.
        self.register_buffer("weight", torch.ones(n))
        self.register_buffer("bias", torch.zeros(n))
        self.register_buffer("running_mean", torch.zeros(n))
        self.register_buffer("running_var", torch.ones(n))

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        # This module has no `num_batches_tracked` buffer, so drop that key before delegating; otherwise it would be
        # reported as an unexpected key.
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x):
        # Reshape the per-channel buffers to (1, C, 1, 1) so they broadcast against `x` along dimension 1.
        weight = self.weight.reshape(1, -1, 1, 1)
        bias = self.bias.reshape(1, -1, 1, 1)
        running_var = self.running_var.reshape(1, -1, 1, 1)
        running_mean = self.running_mean.reshape(1, -1, 1, 1)
        # Fold normalization and affine transform into a single scale and shift:
        # (x - mean) * weight / sqrt(var + eps) + bias == x * scale + (bias - mean * scale)
        scale = weight * (running_var + self.eps).rsqrt()
        bias = bias - running_mean * scale
        return x * scale + bias
# Modified from transformers.models.detr.modeling_deformable_detr.DeformableDetrMultiscaleDeformableAttention with DeformableDetr->OneFormerPixelDecoderEncoder
class OneFormerPixelDecoderEncoderMultiscaleDeformableAttention(nn.Module):
"""
Multiscale deformable attention as proposed in Deformable DETR.
"""
def __init__(self, embed_dim: int, num_heads: int, n_levels: int, n_points: int):
super().__init__()
if embed_dim % num_heads != 0:
raise ValueError(
f"embed_dim (d_model) must be divisible by num_heads, but got {embed_dim} and {num_heads}"
)
dim_per_head = embed_dim // num_heads
# check if dim_per_head is power of 2
if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
warnings.warn(
"You'd better set embed_dim (d_model) in DeformableDetrMultiscaleDeformableAttention to make the"
" dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
" implementation."
)
self.im2col_step = 128
self.d_model = embed_dim
self.n_levels = n_levels
self.n_heads = num_heads
self.n_points = n_points
self.sampling_offsets = nn.Linear(embed_dim, num_heads * n_levels * n_points * 2)
self.attention_weights = nn.Linear(embed_dim, num_heads * n_levels * n_points)
self.value_proj = nn.Linear(embed_dim, embed_dim)
self.output_proj = nn.Linear(embed_dim, embed_dim)
def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Optional[Tensor]):
    """
    Add position embeddings to `tensor` when they are provided.

    Returns `tensor` itself (unchanged) if `position_embeddings` is `None`, otherwise the element-wise sum
    `tensor + position_embeddings`.
    """
    if position_embeddings is None:
        return tensor
    return tensor + position_embeddings
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None,