transforms.py
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Sequence

import torch
import torch.nn as nn

class DepthNorm(nn.Module):
    """
    Normalize the depth channel: in an RGBD input of shape (4, H, W),
    only the last channel is modified.
    The depth channel is also clamped from below at ``min_depth``. The MiDaS
    depth prediction model outputs inverse depth maps, where negative values
    correspond to far-away distances, so they can safely be clamped.
    """

    def __init__(
        self,
        max_depth: float,
        clamp_max_before_scale: bool = False,
        min_depth: float = 0.01,
    ):
        """
        Args:
            max_depth (float): The max value of depth for the dataset
            clamp_max_before_scale (bool): Whether to clamp to max_depth
                before dividing by max_depth
            min_depth (float): Lower bound used to clamp the depth channel
        """
        super().__init__()
        if max_depth <= 0.0:
            raise ValueError(f"max_depth must be > 0; got {max_depth:.2f}")
        self.max_depth = max_depth
        self.clamp_max_before_scale = clamp_max_before_scale
        self.min_depth = min_depth

    def forward(self, image: torch.Tensor):
        C, H, W = image.shape
        if C != 4:
            raise ValueError(
                f"This transform is for 4-channel RGBD input only; got {image.shape}"
            )
        color_img = image[:3, ...]  # (3, H, W)
        depth_img = image[3:4, ...]  # (1, H, W)
        # Clamp from below at min_depth to remove negative (inverse) depth values
        depth_img = depth_img.clamp(min=self.min_depth)
        # Optionally clamp from above at max_depth, then divide by max_depth
        if self.clamp_max_before_scale:
            depth_img = depth_img.clamp(max=self.max_depth)
        depth_img /= self.max_depth
        img = torch.cat([color_img, depth_img], dim=0)
        return img
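
# Usage sketch (illustrative, with an assumed max depth of 75.0):
#
#   rgbd = torch.cat([torch.rand(3, 224, 224), torch.rand(1, 224, 224) * 75.0])
#   out = DepthNorm(max_depth=75.0, clamp_max_before_scale=True)(rgbd)
#   assert out[:3].equal(rgbd[:3])  # RGB channels are untouched
#   assert 0.0 <= out[3].min() and out[3].max() <= 1.0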

class TemporalCrop(nn.Module):
    """
    Convert the video into smaller clips temporally.
    """

    def __init__(self, frames_per_clip: int = 8, stride: int = 8):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride

    def forward(self, video):
        assert video.ndim == 4, "Must be (C, T, H, W)"
        res = []
        # Slide a frames_per_clip-long window along the time axis with the given stride
        for start in range(0, video.size(1) - self.frames + 1, self.stride):
            res.append(video[:, start : start + self.frames, ...])
        return res
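
# Usage sketch (illustrative): for T frames, frames_per_clip=f and stride=s,
# the number of clips is floor((T - f) / s) + 1 (zero when T < f), e.g.:
#
#   clips = TemporalCrop(frames_per_clip=8, stride=8)(torch.rand(3, 32, 224, 224))
#   assert len(clips) == 4 and clips[0].shape == (3, 8, 224, 224)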

class SpatialCrop(nn.Module):
    """
    Convert each video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with -2 as the
    spatial crop at the SlowFast augmentation stage (so that full frames are
    passed in here). Returns a larger list with the 3x spatial crops added.
    Useful for 3x4 testing (e.g. in SwinT) or 3x10 testing in SlowFast, etc.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 3:
            self.crops_to_ext = [0, 1, 2]  # left/top, center, right/bottom
        elif num_crops == 1:
            self.crops_to_ext = [1]  # center crop only
        else:
            raise NotImplementedError(
                "Nothing else supported yet; "
                "uniform_crop only takes 0, 1, 2 as spatial indices"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of (C, T, H, W) videos.
        Returns:
            videos: A list with num_crops times the number of elements. Each
                video is converted to (C, T, H', W') by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all(video.ndim == 4 for video in videos), "Must be (C, T, H, W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
        return res

def crop_boxes(boxes, x_offset, y_offset):
    """
    Perform a crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to perform crop on. The
            dimension is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    # Boxes are (x1, y1, x2, y2) rows; shift them into the crop's coordinate frame
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
    return cropped_boxes
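
# Usage sketch (illustrative): shifting a single (x1, y1, x2, y2) box by the
# crop offsets keeps it aligned with the cropped image:
#
#   import numpy as np
#   boxes = np.array([[50.0, 40.0, 150.0, 140.0]])
#   crop_boxes(boxes, x_offset=48, y_offset=16)  # -> [[2., 24., 102., 124.]]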

def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop on. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and width to crop the images to.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optional. If not None, resize the images to scale_size
            before performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]
    if scale_size is not None:
        # Resize so the shorter side equals scale_size, preserving aspect ratio
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images, size=(height, width), mode="bilinear", align_corners=False
        )
    # Default to a center crop; shift along the longer side for indices 0 and 2
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))
    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset : y_offset + size, x_offset : x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
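

if __name__ == "__main__":
    # Minimal smoke test (illustrative; the shapes and sizes are arbitrary):
    # run a video through the temporal and then the spatial cropping stages.
    video = torch.rand(3, 16, 256, 320)  # (C, T, H, W)
    clips = TemporalCrop(frames_per_clip=8, stride=8)(video)
    crops = SpatialCrop(crop_size=224, num_crops=3)(clips)
    # 16 frames / stride 8 -> 2 temporal clips; each yields 3 spatial crops
    assert len(crops) == 6
    assert all(crop.shape == (3, 8, 224, 224) for crop in crops)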