-
Notifications
You must be signed in to change notification settings - Fork 0
/
special_tokens.py
120 lines (92 loc) · 4.04 KB
/
special_tokens.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Literal, Optional, Union, overload
import torch
from loguru import logger
from pydantic import BaseModel, Field
from emma_experience_hub.datamodels import EmmaExtractedFeatures
from emma_experience_hub.datamodels.simbot.payloads import SimBotObjectMaskType
from emma_experience_hub.functions.simbot.masks import compress_segmentation_mask
class SimBotSceneObjectTokens(BaseModel):
    """Frame/object token indices recorded when an object is found in the scene.

    Both indices are validated to be strictly positive (presumably 1-based token
    indices — confirm against the token parser). `object_index` defaults to None.
    """

    frame_index: int = Field(1, gt=0)
    object_index: Optional[int] = Field(None, gt=0)
def extract_index_from_special_token(token: str) -> int:
    """Parse the integer suffix of a special token such as ``<frame_token_3>``.

    Surrounding whitespace is stripped, the text after the last underscore is kept,
    every ``>`` in it is removed, and the remainder is converted with ``int`` (so a
    non-numeric suffix raises ``ValueError``).
    """
    _, _, suffix = token.strip().rpartition("_")
    digits = suffix.replace(">", "")
    return int(digits)
@overload
def get_mask_from_special_tokens(
    frame_index: int,
    object_index: int,
    extracted_features: list[EmmaExtractedFeatures],
    return_coords: Literal[False] = False,
) -> SimBotObjectMaskType:
    ...  # noqa: WPS428


@overload
def get_mask_from_special_tokens(
    frame_index: int,
    object_index: int,
    extracted_features: list[EmmaExtractedFeatures],
    return_coords: Literal[True],
) -> tuple[SimBotObjectMaskType, tuple[float, ...]]:
    ...  # noqa: WPS428


def get_mask_from_special_tokens(
    frame_index: int,
    object_index: int,
    extracted_features: list[EmmaExtractedFeatures],
    return_coords: bool = False,
) -> Union[SimBotObjectMaskType, tuple[SimBotObjectMaskType, tuple[float, ...]]]:
    """Build a rectangular object mask from the bounding box referenced by visual tokens.

    Args:
        frame_index: Frame index from the frame token. It is converted to a list
            position with ``frame_index - 1``, so it is expected to be 1-based.
        object_index: Object index from the visual token, likewise shifted by one.
        extracted_features: Per-frame features. The selected frame's ``bbox_coords``,
            ``width`` and ``height`` are read.
        return_coords: When True, also return the ``(x_min, y_min, x_max, y_max)``
            coordinates of the box.

    Returns:
        The mask after ``compress_segmentation_mask``, or a tuple of that mask and the
        bounding-box coordinates when ``return_coords`` is True.
    """
    frame_features = extracted_features[frame_index - 1]

    # NOTE(review): assumes each bbox_coords row is (x_min, y_min, x_max, y_max) in
    # pixel units — confirm against the feature extractor.
    (x_min, y_min, x_max, y_max) = frame_features.bbox_coords[object_index - 1].tolist()

    # The slice below indexes rows with y and columns with x, so the mask must be
    # (height, width); a (width, height) mask is transposed for non-square frames.
    mask = torch.zeros((frame_features.height, frame_features.width))

    # Fill the box; the +1 makes the max edges inclusive.
    mask[int(y_min) : int(y_max) + 1, int(x_min) : int(x_max) + 1] = 1  # noqa: WPS221
    compressed_mask = compress_segmentation_mask(mask)

    if return_coords:
        return compressed_mask, (x_min, y_min, x_max, y_max)
    return compressed_mask
def get_class_name_from_special_tokens(
    frame_index: int,
    object_index: int,
    extracted_features: list[EmmaExtractedFeatures],
) -> str:
    """Look up the class label of the object referenced by the visual tokens.

    The frame is selected with ``frame_index - 1`` and the label with
    ``object_index - 1``.

    Raises:
        AssertionError: If the selected frame's ``entity_labels`` is missing or empty.
    """
    labels = extracted_features[frame_index - 1].entity_labels
    if labels:
        return labels[object_index - 1]
    raise AssertionError("Entity labels do not exist for features")
def get_correct_frame_index(
    parsed_frame_index: int,
    num_frames_in_current_turn: int,
    num_total_frames: int,
    *,
    max_frame_index: int = 3,
) -> int:
    """Convert a parsed frame index into a 0-based index within the current turn.

    The current turn's frames are taken to be the last ``num_frames_in_current_turn``
    of ``num_total_frames``. ``parsed_frame_index`` is offset so that the first of
    those frames maps to 0.

    Args:
        parsed_frame_index: Frame index parsed from the prediction. It appears to be
            1-based over all frames, given the ``+ 1`` in the start offset.
        num_frames_in_current_turn: Number of frames belonging to the current turn.
        num_total_frames: Total number of frames.
        max_frame_index: Inclusive upper bound for the returned index when the turn
            does not have exactly one frame. Defaults to 3.

    Returns:
        0 when the turn has a single frame, logging a warning if the offset index was
        not 0. Otherwise, the offset index clamped to ``[0, max_frame_index]``.
    """
    # Position (in the same numbering as parsed_frame_index) of the turn's first frame.
    start_frame_index = num_total_frames - num_frames_in_current_turn + 1
    frame_index = parsed_frame_index - start_frame_index

    if num_frames_in_current_turn == 1 and frame_index != 0:
        # With a single frame there is only one valid answer.
        logger.warning(f"Predicted frame index: {frame_index} instead of 0.")
        return 0

    # NOTE(review): this bound does not consider num_frames_in_current_turn, so a turn
    # with fewer than max_frame_index + 1 frames can still get an out-of-range index.
    return min(max(frame_index, 0), max_frame_index)
def class_label_is_unique_in_frame(
    frame_index: int,
    class_label: str,
    extracted_features: list[EmmaExtractedFeatures],
) -> bool:
    """Return True when exactly one object in the frame has label ``class_label``.

    The frame is selected with ``frame_index - 1``.

    Raises:
        AssertionError: If the selected frame's ``entity_labels`` is missing or empty.
    """
    frame_labels = extracted_features[frame_index - 1].entity_labels
    if not frame_labels:
        raise AssertionError("Entity labels do not exist for features")
    occurrences = frame_labels.count(class_label)
    return occurrences == 1