168 changes: 132 additions & 36 deletions captum/optim/_param/image/transforms.py
@@ -178,59 +178,99 @@ def forward(self, x: torch.Tensor, inverse: bool = False) -> torch.Tensor:

class CenterCrop(torch.nn.Module):
"""
    Center crop a specified amount from a tensor. If the input is smaller than the
    specified crop size, padding will be applied.
"""

__constants__ = [
"size",
"pixels_from_edges",
"offset_left",
"padding_mode",
"padding_value",
]

def __init__(
self,
size: IntSeqOrIntType = 0,
pixels_from_edges: bool = False,
offset_left: bool = False,
padding_mode: str = "constant",
padding_value: float = 0.0,
) -> None:
"""
Args:

            size (int or sequence of int): The exact output size to center crop to,
                or the total number of pixels to crop away along each dimension
                when `pixels_from_edges` is True.
            pixels_from_edges (bool, optional): Whether to treat crop size
                values as the number of pixels from the tensor's edge, or an
                exact shape in the center.
                Default: False
offset_left (bool, optional): If the cropped away sides are not
equal in size, offset center by +1 to the left and/or top.
This parameter is only valid when `pixels_from_edges` is False.
Default: False
            padding_mode (str, optional): One of "constant", "reflect", "replicate"
or "circular". This parameter is only used if the crop size is larger
than the image size.
Default: "constant"
padding_value (float, optional): fill value for "constant" padding. This
parameter is only used if the crop size is larger than the image size.
Default: 0.0
"""
super().__init__()
if not hasattr(size, "__iter__"):
size = [int(size), int(size)]
elif isinstance(size, (tuple, list)):
if len(size) == 1:
size = list((size[0], size[0]))
elif len(size) == 2:
size = list(size)
else:
raise ValueError("Crop size length of {} too large".format(len(size)))
else:
raise ValueError("Unsupported crop size value {}".format(size))
assert len(size) == 2
self.size = cast(List[int], size)
self.pixels_from_edges = pixels_from_edges
self.offset_left = offset_left
self.padding_mode = padding_mode
self.padding_value = padding_value

@torch.jit.ignore
def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Center crop an input.

Args:
input (torch.Tensor): Input to center crop.

Returns:
**tensor** (torch.Tensor): A center cropped *tensor*.
"""

return center_crop(
input,
self.size,
self.pixels_from_edges,
self.offset_left,
self.padding_mode,
self.padding_value,
)
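
For reference, a minimal usage sketch of CenterCrop as defined above. This is an illustration, not code from this diff; the import path and example shapes are assumptions based on the file being modified.

import torch
from captum.optim._param.image.transforms import CenterCrop

x = torch.ones(1, 3, 224, 224)

# Exact-size center crop (pixels_from_edges=False is the default).
crop = CenterCrop(size=160)
assert crop(x).shape == (1, 3, 160, 160)

# Requested size larger than the input: padding is applied before cropping.
pad_crop = CenterCrop(size=256, padding_mode="constant", padding_value=0.0)
assert pad_crop(x).shape == (1, 3, 256, 256)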


def center_crop(
input: torch.Tensor,
size: Union[int, List[int]],
pixels_from_edges: bool = False,
offset_left: bool = False,
padding_mode: str = "constant",
padding_value: float = 0.0,
) -> torch.Tensor:
"""
    Center crop a specified amount from a tensor. If the input is smaller than the
    specified crop size, padding will be applied.

Args:

input (tensor): A CHW or NCHW image tensor to center crop.
        size (int or sequence of int): The exact output size to center crop to, or
            the total number of pixels to crop away along each dimension when
            `pixels_from_edges` is True.
pixels_from_edges (bool, optional): Whether to treat crop size
@@ -241,35 +281,63 @@ def center_crop(
equal in size, offset center by +1 to the left and/or top.
This parameter is only valid when `pixels_from_edges` is False.
Default: False

        padding_mode (str, optional): One of "constant", "reflect", "replicate" or
"circular". This parameter is only used if the crop size is larger than
the image size.
Default: "constant"
padding_value (float, optional): fill value for "constant" padding. This
parameter is only used if the crop size is larger than the image size.
Default: 0.0
Returns:
**tensor**: A center cropped *tensor*.
"""

if isinstance(size, int):
size = [int(size), int(size)]
elif isinstance(size, (tuple, list)):
if len(size) == 1:
size = [size[0], size[0]]
elif len(size) == 2:
size = list(size)
else:
raise ValueError("Crop size length of {} too large".format(len(size)))
else:
raise ValueError("Unsupported crop size value {}".format(size))
assert len(size) == 2

    if input.dim() == 4:
        h, w = input.shape[2:]
    elif input.dim() == 3:
        h, w = input.shape[1:]
    else:
        raise ValueError("Input has too many dimensions: {}".format(input.dim()))

if pixels_from_edges:
        h_crop = h - size[0]
        w_crop = w - size[1]
sw, sh = w // 2 - (w_crop // 2), h // 2 - (h_crop // 2)
x = input[..., sh : sh + h_crop, sw : sw + w_crop]
else:
        h_crop = h - int(math.ceil((h - size[0]) / 2.0)) if h > size[0] else size[0]
        w_crop = w - int(math.ceil((w - size[1]) / 2.0)) if w > size[1] else size[1]

        if h % 2 == 0 and size[0] % 2 != 0 or h % 2 != 0 and size[0] % 2 == 0:
            h_crop = h_crop + 1 if offset_left else h_crop
        if w % 2 == 0 and size[1] % 2 != 0 or w % 2 != 0 and size[1] % 2 == 0:
            w_crop = w_crop + 1 if offset_left else w_crop

        if size[0] > h or size[1] > w:
            # Padding functionality like Torchvision's center crop.
            # F.pad expects pad amounts as (left, right, top, bottom) for the
            # last two dimensions, so the width pair comes first.
            padding = [
                (size[1] - w) // 2 if size[1] > w else 0,
                (size[1] - w + 1) // 2 if size[1] > w else 0,
                (size[0] - h) // 2 if size[0] > h else 0,
                (size[0] - h + 1) // 2 if size[0] > h else 0,
            ]
            input = F.pad(input, padding, mode=padding_mode, value=padding_value)

        x = input[..., h_crop - size[0] : h_crop, w_crop - size[1] : w_crop]
return x
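
A rough sketch of the two size interpretations and the padding path handled by center_crop above; the import path is an assumption, and the shapes follow from the logic shown in this hunk.

import torch
from captum.optim._param.image.transforms import center_crop

x = torch.ones(1, 3, 32, 32)

# Default: size is the exact spatial shape of the center region.
assert center_crop(x, size=24).shape == (1, 3, 24, 24)

# pixels_from_edges=True: size is the total number of pixels removed per dimension.
assert center_crop(x, size=8, pixels_from_edges=True).shape == (1, 3, 24, 24)

# A crop larger than the input pads first, similar to torchvision's center crop.
assert center_crop(x, size=48, padding_mode="constant").shape == (1, 3, 48, 48)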


@@ -779,13 +847,14 @@ class RandomCrop(nn.Module):
Randomly crop out a specific size from an NCHW image tensor.
"""

__constants__ = ["crop_size"]

def __init__(
self,
crop_size: IntSeqOrIntType,
) -> None:
"""
Args:

            crop_size (int or sequence of int): The desired cropped output size.
"""
super().__init__()
@@ -795,20 +864,47 @@ def __init__(
assert len(crop_size) == 2
self.crop_size = crop_size

def _center_crop(self, x: torch.Tensor) -> torch.Tensor:
"""
        Center crop an NCHW image tensor based on self.crop_size.

        Args:

            x (torch.Tensor): The NCHW image tensor to center crop.

        Returns:
            x (torch.Tensor): The center cropped NCHW image tensor.
"""
h, w = x.shape[2:]
h_crop = h - int(math.ceil((h - self.crop_size[0]) / 2.0))
w_crop = w - int(math.ceil((w - self.crop_size[1]) / 2.0))
return x[
...,
h_crop - self.crop_size[0] : h_crop,
w_crop - self.crop_size[1] : w_crop,
]

def forward(self, x: torch.Tensor) -> torch.Tensor:
assert x.dim() == 4
        hs = int(math.ceil((x.shape[2] - self.crop_size[0]) / 2.0))
        ws = int(math.ceil((x.shape[3] - self.crop_size[1]) / 2.0))
shifts = [
            torch.randint(
                low=-hs,
                high=hs,
                size=[1],
                dtype=torch.int64,
                layout=torch.strided,
                device=x.device,
            ),
            torch.randint(
                low=-ws,
                high=ws,
                size=[1],
                dtype=torch.int64,
                layout=torch.strided,
                device=x.device,
            ),
]
        x = torch.roll(x, [int(s) for s in shifts], dims=(2, 3))
        return self._center_crop(x)
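
A small usage sketch for RandomCrop as defined above (a random spatial shift via torch.roll followed by a center crop); the seed and import path are assumptions for illustration.

import torch
from captum.optim._param.image.transforms import RandomCrop

torch.manual_seed(0)  # only to make the sketch deterministic
x = torch.arange(3 * 64 * 64, dtype=torch.float32).reshape(1, 3, 64, 64)

random_crop = RandomCrop(crop_size=[48, 48])
out = random_crop(x)
assert out.shape == (1, 3, 48, 48)  # spatial size matches crop_size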


__all__ = [
Expand Down
Loading