Skip to content

Commit

Permalink
Merge pull request #1894 from TaranSinghania/master
Browse files Browse the repository at this point in the history
Added Data Augmentation Support
  • Loading branch information
Bharath Ramsundar committed Jun 22, 2020
2 parents daa2a72 + bda1d7c commit 34aa0b4
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 2 deletions.
39 changes: 39 additions & 0 deletions deepchem/trans/tests/test_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,35 @@ def test_blurring(self):
check_blur = scipy.ndimage.gaussian_filter(self.d, 1.5)
assert np.allclose(check_blur, blurred)

def test_center_crop(self):
# Check center crop
dt = DataTransforms(self.d)
x_crop = 50
y_crop = 50
crop = dt.center_crop(x_crop, y_crop)
y = self.d.shape[0]
x = self.d.shape[1]
x_start = x // 2 - (x_crop // 2)
y_start = y // 2 - (y_crop // 2)
check_crop = self.d[y_start:y_start + y_crop, x_start:x_start + x_crop]
assert np.allclose(check_crop, crop)

def test_crop(self):
#Check crop
dt = DataTransforms(self.d)
crop = dt.crop(0, 10, 0, 10)
y = self.d.shape[0]
x = self.d.shape[1]
check_crop = self.d[10:y - 10, 0:x - 0]
assert np.allclose(crop, check_crop)

def test_convert2gray(self):
# Check convert2gray
dt = DataTransforms(self.d)
gray = dt.convert2gray()
check_gray = np.dot(self.d[..., :3], [0.2989, 0.5870, 0.1140])
assert np.allclose(check_gray, gray)

def test_rotation(self):
# Check rotation
dt = DataTransforms(self.d)
Expand Down Expand Up @@ -677,3 +706,13 @@ def test_DAG_transformer(self):
# atoms. These are denoted the "parents"
for idm, mol in enumerate(dataset.X):
assert dataset.X[idm].get_num_atoms() == len(dataset.X[idm].parents)

def test_median_filter(self):
#Check median filter
from PIL import Image, ImageFilter
dt = DataTransforms(self.d)
filtered = dt.median_filter(size=3)
image = Image.fromarray(self.d)
image = image.filter(ImageFilter.MedianFilter(size=3))
check_filtered = np.array(image)
assert np.allclose(check_filtered, filtered)
81 changes: 79 additions & 2 deletions deepchem/trans/transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1306,8 +1306,15 @@ def flip(self, direction="lr"):

def rotate(self, angle=0):
""" Rotates the image
Parameters:
angle (default = 0 i.e no rotation) - Denotes angle by which the image should be rotated (in Degrees)
Parameters
----------
angle: float (default = 0 i.e no rotation)
Denotes angle by which the image should be rotated (in Degrees)
Returns
----------
The rotated imput array
"""
return scipy.ndimage.rotate(self.Image, angle)

Expand All @@ -1318,6 +1325,59 @@ def gaussian_blur(self, sigma=0.2):
"""
return scipy.ndimage.gaussian_filter(self.Image, sigma)

def center_crop(self, x_crop, y_crop):
""" Crops the image from the center
Parameters
----------
x_crop: int
the total number of pixels to remove in the horizontal direction, evenly split between the left and right sides
y_crop: int
the total number of pixels to remove in the vertical direction, evenly split between the top and bottom sides
Returns
----------
The center cropped input array
"""
y = self.Image.shape[0]
x = self.Image.shape[1]
x_start = x // 2 - (x_crop // 2)
y_start = y // 2 - (y_crop // 2)
return self.Image[y_start:y_start + y_crop, x_start:x_start + x_crop]

def crop(self, left, top, right, bottom):
""" Crops the image and returns the specified rectangular region from an image
Parameters
----------
left: int
the number of pixels to exclude from the left of the image
top: int
the number of pixels to exclude from the top of the image
right: int
the number of pixels to exclude from the right of the image
bottom: int
the number of pixels to exclude from the bottom of the image
Returns
----------
The cropped input array
"""
y = self.Image.shape[0]
x = self.Image.shape[1]
return self.Image[top:y - bottom, left:x - right]

def convert2gray(self):
""" Converts the image to grayscale. The coefficients correspond to the Y' component of the Y'UV color system.
Returns
----------
The grayscale image.
"""
return np.dot(self.Image[..., :3], [0.2989, 0.5870, 0.1140])

def shift(self, width, height, mode='constant', order=3):
"""Shifts the image
Parameters:
Expand Down Expand Up @@ -1358,3 +1418,20 @@ def salt_pepper_noise(self, prob=0.05, salt=255, pepper=0):
x[noise < (prob / 2)] = pepper
x[noise > (1 - prob / 2)] = salt
return x

def median_filter(self, size):
""" Calculates a multidimensional median filter
Parameters
----------
size: int
The kernel size in pixels.
Returns
----------
The median filtered image.
"""
from PIL import Image, ImageFilter
image = Image.fromarray(self.Image)
image = image.filter(ImageFilter.MedianFilter(size=size))
return np.array(image)

0 comments on commit 34aa0b4

Please sign in to comment.