-
Notifications
You must be signed in to change notification settings - Fork 129
/
flava_transform.py
315 lines (274 loc) · 11.8 KB
/
flava_transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
import random
import warnings
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
import numpy as np
import torch
import torchvision.transforms.functional as F
from PIL import Image
from torch import Tensor
from torchvision import transforms
# Default mean (three values) used by FLAVAImageTransform as the `mean` of its
# `transforms.Normalize` step when the caller does not pass `image_mean`.
IMAGE_PRETRAINING_MEAN = (0.48145466, 0.4578275, 0.40821073)
# Default standard deviation (three values) used by FLAVAImageTransform as the
# `std` of its `transforms.Normalize` step when the caller does not pass `image_std`.
IMAGE_PRETRAINING_STD = (0.26862954, 0.26130258, 0.27577711)
# Margin used by `map_pixels`: the affine map it applies is
# x -> (1 - 2 * eps) * x + eps.
LOGIT_LAPLACE_EPS: float = 0.1


def map_pixels(x: torch.Tensor) -> torch.Tensor:
    """Apply the affine map ``(1 - 2 * eps) * x + eps`` to a float tensor.

    ``eps`` is ``LOGIT_LAPLACE_EPS``; with the value 0.1 an input of 0 maps
    to 0.1 and an input of 1 maps to 0.9.

    Args:
        x: tensor whose dtype must be ``torch.float``.

    Returns:
        A new tensor with the same shape as ``x`` holding the rescaled values.

    Raises:
        ValueError: if ``x.dtype`` is not ``torch.float``.
    """
    if x.dtype == torch.float:
        scale = 1 - 2 * LOGIT_LAPLACE_EPS
        return scale * x + LOGIT_LAPLACE_EPS
    raise ValueError("expected input to have type float")
class ImageMaskingGenerator:
    """Generate random masks over a 2D grid of patches.

    Each call returns an integer array of shape ``(height, width)`` in which
    masked positions are 1 and all others 0. The mask is built by repeatedly
    stamping random rectangles onto the grid until ``num_masking_patches``
    positions are masked or no further rectangle can be placed.

    Args:
        input_size: grid size. Either a single int (square grid) or a
            ``(height, width)`` tuple/list.
        num_masking_patches: upper bound on the total number of masked
            positions in a generated mask.
        min_num_patches: lower bound of the area sampled for one rectangle.
        max_num_patches: upper bound on the number of positions a single
            rectangle may newly mask. Defaults to ``num_masking_patches``.
        min_aspect: minimum rectangle aspect ratio.
        max_aspect: maximum rectangle aspect ratio. Defaults to
            ``1 / min_aspect``.
    """

    def __init__(
        self,
        input_size: Union[Tuple[int, int], List[int], int],
        num_masking_patches: int,
        min_num_patches: int = 4,
        max_num_patches: Optional[int] = None,
        min_aspect: float = 0.3,
        max_aspect: Optional[float] = None,
    ) -> None:
        # Accept lists as well as tuples, matching the `(list, tuple)` checks
        # used by the resize transforms in this file. Previously a list was
        # treated as a scalar and duplicated, which broke the unpacking below.
        if not isinstance(input_size, (tuple, list)):
            input_size = (input_size,) * 2
        self.height, self.width = input_size

        self.num_patches = self.height * self.width
        self.num_masking_patches = num_masking_patches

        self.min_num_patches = min_num_patches
        self.max_num_patches = (
            num_masking_patches if max_num_patches is None else max_num_patches
        )

        max_aspect = max_aspect or 1 / min_aspect
        # Aspect ratios are sampled uniformly in log space (see `_mask`).
        self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))

    def __repr__(self) -> str:
        return "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % (
            self.height,
            self.width,
            self.min_num_patches,
            self.max_num_patches,
            self.num_masking_patches,
            self.log_aspect_ratio[0],
            self.log_aspect_ratio[1],
        )

    def get_shape(self) -> Tuple[int, int]:
        """Return the grid size as ``(height, width)``."""
        return self.height, self.width

    def _mask(self, mask: np.ndarray, max_mask_patches: int) -> int:
        """Try to stamp one random rectangle onto ``mask`` (modified in place).

        Up to 10 rectangles are sampled; the first one that fits strictly
        inside the grid and would newly mask between 1 and
        ``max_mask_patches`` positions is applied.

        Args:
            mask: 2D array of zeros and ones, updated in place.
            max_mask_patches: maximum number of positions this call may
                newly mask.

        Returns:
            The number of positions newly set to 1 (0 if every attempt failed).
        """
        delta = 0
        for _attempt in range(10):
            target_area = random.uniform(self.min_num_patches, max_mask_patches)
            aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))
            if not (w < self.width and h < self.height):
                continue

            top = random.randint(0, self.height - h)
            left = random.randint(0, self.width - w)

            # Basic slicing yields a view, so writes to `region` update `mask`.
            region = mask[top : top + h, left : left + w]
            num_masked = region.sum()
            # Accept only rectangles that add at least one new position and
            # do not exceed the remaining budget once overlap is discounted.
            if 0 < h * w - num_masked <= max_mask_patches:
                unmasked = region == 0
                delta += int(unmasked.sum())
                region[unmasked] = 1

            if delta > 0:
                break
        return delta

    def __call__(self) -> np.ndarray:
        """Build and return a fresh ``(height, width)`` int64 mask."""
        mask = np.zeros(shape=self.get_shape(), dtype=np.int64)  # type: ignore
        mask_count = 0
        while mask_count < self.num_masking_patches:
            # Budget for the next rectangle: whatever is left, capped by the
            # per-rectangle maximum.
            max_mask_patches = min(
                self.num_masking_patches - mask_count, self.max_num_patches
            )

            delta = self._mask(mask, max_mask_patches)
            if delta == 0:
                # No rectangle could be placed; stop with a partial mask.
                break
            mask_count += delta

        return mask
class TwoWayResize(transforms.Resize):
    """Resize an image and additionally return a second, differently sized copy.

    Args:
        size: target size of the first output. An int is expanded to
            ``(size, size)``.
        second_size: target size of the second output. An int is expanded to
            ``(second_size, second_size)``. If ``None``, only the first
            resized image is returned.
        second_interpolation: interpolation used for the second output.
            Passing an int emits a warning and converts it to the matching
            ``InterpolationMode``.
        **kwargs: forwarded to ``transforms.Resize`` (e.g. ``interpolation``).
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        second_size: Optional[Union[int, Tuple[int, int]]] = None,
        second_interpolation: transforms.InterpolationMode = transforms.InterpolationMode.LANCZOS,
        **kwargs: Any,
    ) -> None:
        if not isinstance(size, (list, tuple)):
            size = (size, size)
        super().__init__(size, **kwargs)
        # Backward compatibility with integer value
        if isinstance(second_interpolation, int):
            warnings.warn(
                "Argument second_interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            # The helper lives in torchvision.transforms.functional; it is not
            # exported from the `torchvision.transforms` package namespace.
            second_interpolation = F._interpolation_modes_from_int(
                second_interpolation
            )

        # Keep `None` as-is: turning it into `(None, None)` would only produce
        # an invalid size that fails inside `F.resize`.
        if second_size is not None and not isinstance(second_size, (list, tuple)):
            second_size = (second_size, second_size)

        self.second_size = second_size
        self.second_interpolation = second_interpolation

    def forward(
        self, img: Image.Image
    ) -> Union[Image.Image, Tuple[Image.Image, Image.Image]]:
        """Return ``(resized, second_resized)``, or just ``resized`` when
        ``second_size`` is ``None``.
        """
        img = F.resize(
            img, self.size, self.interpolation, self.max_size, self.antialias
        )
        if self.second_size is None:
            return img
        # NOTE(review): the second image is produced from the already-resized
        # first image, not from the original input (unlike
        # TwoWayRandomResizedCrop, which crops the original twice). Left
        # unchanged to preserve existing outputs — confirm this is intended.
        second_img = F.resize(
            img,
            self.second_size,
            self.second_interpolation,
            self.max_size,
            self.antialias,
        )
        return img, second_img
class TwoWayRandomResizedCrop(transforms.RandomResizedCrop):
    """
    Similar to RandomResizedCrop but returns two versions of the
    random crop with different sizings and interpolations.
    Note that the crop is same but the two returned images
    have different final sizes and interpolations

    Args:
        size: output size of the first crop, forwarded to
            ``transforms.RandomResizedCrop``.
        second_size: output size of the second crop. An int is expanded to
            ``(second_size, second_size)``. If ``None``, only the first crop
            is returned.
        second_interpolation: interpolation used for the second crop.
            Passing an int emits a warning and converts it to the matching
            ``InterpolationMode``.
        **kwargs: forwarded to ``transforms.RandomResizedCrop``
            (e.g. ``scale``, ``ratio``, ``interpolation``).
    """

    def __init__(
        self,
        size: Union[int, Tuple[int, int]],
        second_size: Optional[Union[int, Tuple[int, int]]] = None,
        second_interpolation: transforms.InterpolationMode = transforms.InterpolationMode.LANCZOS,
        **kwargs: Any,
    ) -> None:
        super().__init__(size, **kwargs)
        # Backward compatibility with integer value
        if isinstance(second_interpolation, int):
            warnings.warn(
                "Argument second_interpolation should be of type InterpolationMode instead of int. "
                "Please, use InterpolationMode enum."
            )
            # The helper lives in torchvision.transforms.functional; it is not
            # exported from the `torchvision.transforms` package namespace.
            second_interpolation = F._interpolation_modes_from_int(
                second_interpolation
            )

        # Keep `None` as-is so the single-output branch in `__call__` is
        # reachable; previously `None` became `(None, None)` and that branch
        # could never run.
        if second_size is not None and not isinstance(second_size, (list, tuple)):
            second_size = (second_size, second_size)

        self.second_size = second_size
        self.second_interpolation = second_interpolation

    def __call__(
        self, img: Image.Image
    ) -> Union[Image.Image, Tuple[Image.Image, Image.Image]]:
        """Sample one crop region and return it resized one or two ways.

        Returns:
            The crop resized to ``size`` if ``second_size`` is ``None``;
            otherwise a tuple of the same crop resized to ``size`` and to
            ``second_size`` (the latter using ``second_interpolation``).
        """
        # One set of crop parameters is shared by both outputs.
        i, j, h, w = self.get_params(img, self.scale, self.ratio)
        if isinstance(self.interpolation, (tuple, list)):
            interpolation = random.choice(self.interpolation)
        else:
            interpolation = self.interpolation

        first = F.resized_crop(img, i, j, h, w, self.size, interpolation)
        if self.second_size is None:
            return first
        second = F.resized_crop(
            img, i, j, h, w, self.second_size, self.second_interpolation
        )
        return first, second
class FLAVAImageTransform:
    """FLAVA image transform which does basic transforms like resize etc on images,
    generates a random patch mask based on the scheme from Beit
    https://arxiv.org/pdf/2106.08254.pdf and prepares the image that is fed to
    the visual codebook.

    Args:
        is_train (bool): whether transform is applied during training or not.
            A random resized crop is used for training, a plain resize otherwise.
            Defaults to True.
        encoder_input_size (int): size of image that is input to the image encoder. Default is 224.
        codebook_input_size (int): size of image that is input to the visual codebook. Default is 112.
        scale (Tuple[float, float]): scale passed to the random resized crop
            (only used when ``is_train`` is True). Default is (0.9, 1.0).
        encoder_interpolation (InterpolationMode): interpolation for the crop/resize
            of the image passed to the encoder. Default is BICUBIC.
        codebook_interpolation (InterpolationMode): interpolation for the crop/resize
            of the image passed to the visual codebook. Default is LANCZOS.
        image_mean (Tuple[float, float, float]): mean for image normalization.
            Default is (0.48145466, 0.4578275, 0.40821073)
        image_std (Tuple[float, float, float]): standard deviation for image normalization.
            Default is (0.26862954, 0.26130258, 0.27577711)
        mask_window_size (int): dimension of mask. Default is 14.
        mask_num_patches (int): number of patches to mask. Default is 75.
        mask_max_patches (Optional[int]): max number of patches to mask. Default is None.
        mask_min_patches (int): min number of patches to mask. Default is 16.

    Inputs:
        images (Union[List[Image.Image], Image.Image]): input image / list of images

    Returns:
        A mapping with keys ``"image"``, ``"image_for_codebook"`` and
        ``"image_patches_mask"``. For a single input image each value is a
        tensor; for a list input each value is a list with one tensor per image.
    """

    def __init__(
        self,
        is_train: bool = True,
        encoder_input_size: int = 224,
        codebook_input_size: int = 112,
        scale: Tuple[float, float] = (0.9, 1.0),
        encoder_interpolation: transforms.InterpolationMode = transforms.InterpolationMode.BICUBIC,
        codebook_interpolation: transforms.InterpolationMode = transforms.InterpolationMode.LANCZOS,
        image_mean: Tuple[float, float, float] = IMAGE_PRETRAINING_MEAN,
        image_std: Tuple[float, float, float] = IMAGE_PRETRAINING_STD,
        mask_window_size: int = 14,
        mask_num_patches: int = 75,
        mask_max_patches: Optional[int] = None,
        mask_min_patches: int = 16,
    ) -> None:
        # Both variants return a pair: (encoder-sized image, codebook-sized image).
        if is_train:
            resize_func = TwoWayRandomResizedCrop(
                size=encoder_input_size,
                second_size=codebook_input_size,
                scale=scale,
                interpolation=encoder_interpolation,
                second_interpolation=codebook_interpolation,
            )
        else:
            resize_func = TwoWayResize(
                size=encoder_input_size,
                second_size=codebook_input_size,
                interpolation=encoder_interpolation,
                second_interpolation=codebook_interpolation,
            )

        # Shared first stage: force RGB, then produce the two resized images.
        self.common_transform = transforms.Compose(
            [
                transforms.Lambda(
                    lambda img: img.convert("RGB") if img.mode != "RGB" else img
                ),
                resize_func,
            ]
        )

        # Encoder branch: tensor conversion followed by mean/std normalization.
        self.image_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=torch.tensor(image_mean),
                    std=torch.tensor(image_std),
                ),
            ]
        )

        # Codebook branch: tensor conversion followed by `map_pixels`.
        self.codebook_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                map_pixels,
            ]
        )
        self.masked_position_generator = ImageMaskingGenerator(
            mask_window_size,
            num_masking_patches=mask_num_patches,
            max_num_patches=mask_max_patches,
            min_num_patches=mask_min_patches,
        )

    def transform(self, image: Image.Image) -> Dict[str, Tensor]:
        """Transform a single image into encoder input, codebook input and patch mask."""
        image, image_for_codebook = self.common_transform(image)
        return {
            "image": self.image_transform(image),
            "image_for_codebook": self.codebook_transform(image_for_codebook),
            "image_patches_mask": torch.from_numpy(self.masked_position_generator()),
        }

    def __call__(
        self, images: Union[List[Image.Image], Image.Image]
    ) -> Mapping[str, Union[Tensor, List[Tensor]]]:
        """Transform one image, or each image of a list.

        For a list, the per-image outputs are collected key by key into
        lists; an empty list yields an empty dict.
        """
        if not isinstance(images, list):
            return self.transform(images)

        output: Dict[str, List[Tensor]] = {}
        for image in images:
            for key, value in self.transform(image).items():
                output.setdefault(key, []).append(value)
        return output