This repository has been archived by the owner on Jul 2, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 305
/
transforms.py
130 lines (109 loc) · 4.96 KB
/
transforms.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import division
import math
import numpy as np
def _axis_crop_slices(length, size, stride):
    """Compute the source/destination slices of every crop along one axis.

    Args:
        length (int): Extent of the input image along this axis.
        size (int): Extent of each crop along this axis.
        stride (int): Step between the starts of neighboring crops.

    Returns:
        (list of slices, list of slices):
        The first list holds the slices into the input image, the second
        holds the matching slices into the (zero padded) crop.

    """
    # Number of crops needed to cover the axis. The clamp guarantees one
    # crop even when the image is much smaller than the crop size, where
    # the ceil expression evaluates to zero or a negative number.
    n_crops = max(int(math.ceil((length - size) / stride)) + 1, 1)

    # The crop grid overhangs the image by `overhang` pixels in total;
    # center the grid on the image. The unary minus is applied before the
    # floor division, so for an odd overhang the extra pixel lands before
    # the start of the image.
    overhang = size + stride * (n_crops - 1) - length
    start = -overhang // 2

    img_slices = []
    crop_slices = []
    for i in range(n_crops):
        lo = i * stride + start
        hi = lo + size
        # Part of the crop window that lies inside the image.
        img_slices.append(slice(max(lo, 0), min(hi, length)))
        # Where that part sits inside the crop; the rest stays zero.
        crop_slices.append(
            slice(max(-lo, 0), size - max(hi - length, 0)))
    return img_slices, crop_slices


def convolution_crop(img, size, stride, return_param=False):
    """Strided cropping.

    This extracts cropped images from the input. The cropped images are
    extracted from the entire image, while taking constant steps between
    neighboring patches. The grid of crops is centered on the image, and
    the parts of a crop that fall outside of the image are filled with
    zeros. At least one crop is produced along each axis, even when the
    image is smaller than :obj:`size`.

    Args:
        img (~numpy.ndarray): An image array to be cropped. This is in
            CHW format.
        size (tuple): The size of output image after cropping.
            This value is :math:`(height, width)`.
        stride (tuple): The stride between crops. This contains
            two values: stride in the vertical and horizontal directions.
        return_param (bool): If :obj:`True`, this function returns
            information of slices.

    Returns:
        ~numpy.ndarray or (~numpy.ndarray, dict):

        If :obj:`return_param = False`,
        returns an array :obj:`crop_imgs` that is a stack of cropped images.
        The crops are ordered row by row (vertical position is the outer
        loop, horizontal position is the inner loop).

        If :obj:`return_param = True`,
        returns a tuple whose elements are :obj:`crop_imgs, param`.
        :obj:`param` is a dictionary of intermediate parameters whose
        contents are listed below with key, value-type and the description
        of the value.

        * **y_slices** (*list of slices*): Slices used to crop the input
          image. The relation below holds together with :obj:`x_slices`.
        * **x_slices** (*list of slices*): Similar to :obj:`y_slices`.
        * **crop_y_slices** (*list of slices*): This indicates the region
          of the cropped image that is actually extracted from the input.
          This is relevant only when borders of the input are cropped.
        * **crop_x_slices** (*list of slices*): Similar to
          :obj:`crop_y_slices`.

        .. code::

            crop_img = crop_imgs[i][:, crop_y_slices[i], crop_x_slices[i]]
            crop_img == img[:, y_slices[i], x_slices[i]]

    Examples:

        >>> import numpy as np
        >>> from chainercv.datasets import VOCBboxDataset
        >>> from chainercv.transforms import resize
        >>> from chainercv.experimental.links.model.pspnet import (
        ...     convolution_crop)
        >>>
        >>> img, _, _ = VOCBboxDataset(year='2007')[0]
        >>> img = resize(img, (300, 300))
        >>> imgs, param = convolution_crop(
        ...     img, (128, 128), (96, 96), return_param=True)
        >>> # Restore the original image from the cropped images.
        >>> output = np.zeros((3, 300, 300))
        >>> count = np.zeros((300, 300))
        >>> for i in range(len(imgs)):
        ...     crop_y_slice = param['crop_y_slices'][i]
        ...     crop_x_slice = param['crop_x_slices'][i]
        ...     y_slice = param['y_slices'][i]
        ...     x_slice = param['x_slices'][i]
        ...     output[:, y_slice, x_slice] += (
        ...         imgs[i][:, crop_y_slice, crop_x_slice])
        ...     count[y_slice, x_slice] += 1
        >>> output = output / count[None]
        >>> np.testing.assert_equal(output, img)
        >>>
        >>> # Visualization of the cropped images
        >>> import matplotlib.pyplot as plt
        >>> from chainercv.utils import tile_images
        >>> from chainercv.visualizations import vis_image
        >>> v_imgs = tile_images(imgs, 5, fill=122.5)
        >>> vis_image(v_imgs)
        >>> plt.show()

    """
    n_channel, H, W = img.shape

    # The two axes are independent, so the slices are computed once per
    # axis and then combined for every (row, column) pair.
    img_ys, pad_ys = _axis_crop_slices(H, size[0], stride[0])
    img_xs, pad_xs = _axis_crop_slices(W, size[1], stride[1])

    crop_imgs = []
    y_slices = []
    x_slices = []
    crop_y_slices = []
    crop_x_slices = []
    for y_slice, crop_y_slice in zip(img_ys, pad_ys):
        for x_slice, crop_x_slice in zip(img_xs, pad_xs):
            # Start from zeros so that regions outside the image are
            # zero padded.
            crop_img = np.zeros(
                (n_channel, size[0], size[1]), dtype=img.dtype)
            crop_img[:, crop_y_slice, crop_x_slice] = \
                img[:, y_slice, x_slice]

            crop_imgs.append(crop_img)
            y_slices.append(y_slice)
            x_slices.append(x_slice)
            crop_y_slices.append(crop_y_slice)
            crop_x_slices.append(crop_x_slice)

    crop_imgs = np.array(crop_imgs)
    if not return_param:
        return crop_imgs

    param = {'y_slices': y_slices, 'x_slices': x_slices,
             'crop_y_slices': crop_y_slices,
             'crop_x_slices': crop_x_slices}
    return crop_imgs, param