Skip to content
This repository has been archived by the owner on Apr 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #151 from mys007/randomcrop_fix
Browse files Browse the repository at this point in the history
Random crop fix
  • Loading branch information
nasimrahaman committed Nov 3, 2018
2 parents 8487b6a + ac48a81 commit 556ca8d
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions inferno/io/transform/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,12 @@ def clear_random_variables(self):
super(RandomCrop, self).clear_random_variables()

def build_random_variables(self, height_leeway, width_leeway):
self.set_random_variable('height_location',
np.random.randint(low=0, high=height_leeway + 1))
self.set_random_variable('width_location',
np.random.randint(low=0, high=width_leeway + 1))
if height_leeway > 0:
self.set_random_variable('height_location',
np.random.randint(low=0, high=height_leeway + 1))
if width_leeway > 0:
self.set_random_variable('width_location',
np.random.randint(low=0, high=width_leeway + 1))

def image_function(self, image):
# Validate image shape
Expand All @@ -150,6 +152,7 @@ def image_function(self, image):
height_leeway=height_leeway,
width_leeway=width_leeway)
cropped = image[height_location:(height_location + crop_height), :]
assert cropped.shape[0] == self.output_image_shape[0], "Well, shit."
else:
cropped = image
if width_leeway > 0:
Expand All @@ -158,7 +161,7 @@ def image_function(self, image):
height_leeway=height_leeway,
width_leeway=width_leeway)
cropped = cropped[:, width_location:(width_location + crop_width)]
assert cropped.shape == self.output_image_shape, "Well, shit."
assert cropped.shape[1] == self.output_image_shape[1], "Well, shit."
return cropped


Expand Down Expand Up @@ -448,11 +451,14 @@ def __init__(self, size, **super_kwargs):
def image_function(self, image):
h, w = image.shape
th, tw = self.size
x1 = int(round((w - tw) / 2.))
y1 = int(round((h - th) / 2.))
return image[x1:x1 + tw, y1:y1 + th]
if h > th:
y1 = int(round((h - th) / 2.))
image = image[:, y1:y1 + th]
if w > tw:
x1 = int(round((w - tw) / 2.))
image = image[x1:x1 + tw, :]
return image

trafo.image.RandomRotate(),

class BinaryMorphology(Transform):
"""
Expand Down

0 comments on commit 556ca8d

Please sign in to comment.