In [None]:
!pip install scikit-image
import os
import torch
import numpy as np
from torchvision import transforms
from PIL import Image, ImageEnhance
from skimage.util import random_noise

import math
import random
import matplotlib.pyplot as plt

# Side length, in pixels, of the square images produced by the resize steps
# in rotate_by, random_scale and fix_image.
IM_SIZE = 240

Collecting scikit-image
  Downloading scikit_image-0.18.1-cp37-cp37m-manylinux1_x86_64.whl (29.2 MB)
[K     |████████████████████████████████| 29.2 MB 21.5 MB/s 
[?25hCollecting networkx>=2.0
  Downloading networkx-2.5-py3-none-any.whl (1.6 MB)
[K     |████████████████████████████████| 1.6 MB 21.7 MB/s 
Collecting PyWavelets>=1.1.1
  Downloading PyWavelets-1.1.1-cp37-cp37m-manylinux1_x86_64.whl (4.4 MB)
[K     |████████████████████████████████| 4.4 MB 50.2 MB/s 
Collecting tifffile>=2019.7.26
  Downloading tifffile-2021.3.17-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 15.1 MB/s 
[?25hCollecting imageio>=2.3.0
  Downloading imageio-2.9.0-py3-none-any.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 15.7 MB/s 
Installing collected packages: networkx, PyWavelets, tifffile, imageio, scikit-image
Successfully installed PyWavelets-1.1.1 imageio-2.9.0 networkx-2.5 scikit-image-0.18.1 tifffile-2021.3.17
You should consider upgrading via the '/

In [None]:
'''
    Definition of the horizontal flip function

    Flips one image

    Params
    image: PIL.Image.Image object or torch.Tensor will both work

    Returns
        a flipped version of a PIL Image or Tensor along the horizontal axis
'''

def horizontal_flip(image):
    was_pil = False
    
    if isinstance(image, Image.Image):
        image = transforms.functional.to_tensor(image)
        was_pil = True
    elif isinstance(image, torch.Tensor) and len(image.shape) != 3:
        raise Exception("Cannot mirror multiple images at once. Input image(s) is not of the format CHW")
    
    image = torch.flip(image, (2,))

    if was_pil:
        image = transforms.functional.to_pil_image(image)

    return image

# img = Image.open('./Mushrooms/Agaricus/000_ePQknW8cTp8.jpg')
# flipped = horizontal_flip(img)

# plt.figure(1)
# plt.imshow(img)

# plt.figure(2)
# plt.imshow(flipped)

# img.save('./poggers.jpg')

In [None]:
def rotate_by(image, angle=None, resize_image=False):
    '''
    Definition of the rotation function

    Rotates one image about its centre and crops out the blank space the
    rotation creates. The crop is a centred square whose side is the shorter
    side of the largest axis-aligned rectangle that fits inside the rotated
    image, so non-square inputs lose part of their longer side.

    Params
    image: PIL.Image.Image object
    angle: rotation in degrees, leave empty for a random integer between -60 and 60 (inclusive)
    resize_image: set to True to resize the cropped result to IM_SIZE x IM_SIZE

    Returns
        a rotated and cropped version of the input image
        (an image with zero width or height is returned unchanged)
    '''
    w, h = image.size

    if angle is None:
        angle = random.randint(-60, 60)

    # A zero-area image has nothing to rotate or crop; hand it back as-is so
    # the function always returns an image.
    if w <= 0 or h <= 0:
        return image

    ########################################################################################################
    # algorithm from https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
    #     by: coproc

    width_is_longer = w >= h
    side_long, side_short = (w, h) if width_is_longer else (h, w)

    # since the solutions for angle, -angle and 180-angle are all the same,
    # it suffices to look at the first quadrant and the absolute values of sin,cos:
    theta = math.radians(angle)
    sin_a, cos_a = abs(math.sin(theta)), abs(math.cos(theta))
    if side_short <= 2. * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
        # half constrained case: two crop corners touch the longer side,
        #   the other two corners are on the mid-line parallel to the longer line
        x = 0.5 * side_short
        wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
    else:
        # fully constrained case: crop touches all 4 sides
        cos_2a = cos_a * cos_a - sin_a * sin_a
        wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
    ########################################################################################################

    # keep a square: use the shorter side of the inscribed rectangle, centred
    r = min(wr, hr)
    box = ((w - r) / 2, (h - r) / 2, (w + r) / 2, (h + r) / 2)

    image = image.rotate(angle=angle).crop(box)

    if resize_image:
        image = image.resize((IM_SIZE, IM_SIZE))

    return image

# img = Image.open('./Mushrooms/Agaricus/000_ePQknW8cTp8.jpg')
# rotated = rotate_by(img, angle=None, resize_image=False)

# plt.figure(1)
# plt.imshow(img)

# plt.figure(2)
# plt.imshow(rotated)

In [None]:
def random_scale(image, zoom=None, resize_image=False):
    '''
    Definition of image scaling function

    Zooms into the image by cropping a randomly positioned square whose side
    is min(width, height) / zoom. A zoom of 1 or less performs no crop.

    Params
    image: PIL.Image.Image object
    zoom: zoom factor, float greater than 1. Do not specify for random in range [1.1, 1.5)
    resize_image: set to True to resize the result to IM_SIZE x IM_SIZE
                  (honoured even when no crop is performed)

    Returns
        a randomly scaled version of the input image
    '''
    if zoom is None:
        zoom = random.random() * 0.4 + 1.1

    if zoom > 1:
        w, h = image.size
        r = min(w / zoom, h / zoom)

        # choose the top-left corner so the square stays inside the image
        left = random.randint(0, int(w - r))
        top = random.randint(0, int(h - r))

        image = image.crop((left, top, left + r, top + r))

    if resize_image:
        image = image.resize((IM_SIZE, IM_SIZE))

    return image

# img = Image.open('./Mushrooms/Agaricus/000_ePQknW8cTp8.jpg')
# scaled = random_scale(img, zoom=None, resize_image=False)

# plt.figure(1)
# plt.imshow(img)

# plt.figure(2)
# plt.imshow(scaled)

In [None]:
def augBrightness(image, factor):
  '''
    Brightness augmentation

    Params
    image: PIL.Image.Image object
    factor: enhancement factor handed to PIL's Brightness enhancer
            (this notebook uses 0.5 to darken and 2.0 to brighten)

    Returns
        the brightness-adjusted image
  '''
  return ImageEnhance.Brightness(image).enhance(factor)

def augContrast(image, factor):
  '''
    Contrast augmentation

    Params
    image: PIL.Image.Image object
    factor: enhancement factor handed to PIL's Contrast enhancer
            (this notebook uses 0.5 to lower and 2.0 to raise contrast)

    Returns
        the contrast-adjusted image
  '''
  return ImageEnhance.Contrast(image).enhance(factor)

def augNoise(image, var):
  '''
    Noise augmentation

    Adds noise with skimage's random_noise (called with its default mode)
    and converts the result back to an 8-bit PIL image.

    Params
    image: PIL.Image.Image object
    var: noise variance handed to random_noise

    Returns
        a noisy copy of the input as a PIL Image
  '''
  pixels = np.array(image)
  noisy = random_noise(pixels, var=var)
  # scale to the 0-255 range and cast to 8-bit (the cast truncates)
  as_bytes = np.uint8(noisy * 255)
  return Image.fromarray(as_bytes)

def augBlur(image, factor):
  '''
    Blur augmentation, implemented with PIL's Sharpness enhancer

    Params
    image: PIL.Image.Image object
    factor: enhancement factor handed to the Sharpness enhancer
            (this notebook uses 0.8 for a low blur and 0.7 for a high blur)

    Returns
        the adjusted image
  '''
  return ImageEnhance.Sharpness(image).enhance(factor)

In [None]:
# Augmentation configuration.
# There are 13 modifications in total (9 in `trans`, 4 in `trans_orig`);
# fill() cycles through them so each one is used about equally often.

# directory for un-augmented image sets
ROOT_DIR = './Mush_train'

# One sub-folder per class. Hidden entries and stray files are skipped so the
# per-class loops only ever see real class folders; sorted for a stable order.
class_dirs = sorted(
    entry for entry in os.listdir(ROOT_DIR)
    if not entry.startswith('.') and os.path.isdir(os.path.join(ROOT_DIR, entry))
)

# directory for augmented image sets
AUG_DIR_1 = './Mushaugm1'
AUG_DIR_2 = './Mushaugm2'

# augmenting process: number of images each class folder is filled up to
lim1 = 1000
lim2 = 2000

# Transforms that fill() applies to an image that has already been through fix_image
trans = [
    lambda im: augBrightness(im, 0.5),  # 0 darken
    lambda im: augBrightness(im, 2.0),  # 1 brighten
    lambda im: augContrast(im, 0.5),    # 2 contrast down
    lambda im: augContrast(im, 2.0),    # 3 contrast up
    lambda im: horizontal_flip(im),     # 4 horizontal flip
    lambda im: augNoise(im, 0.01),      # 5 low noise
    lambda im: augNoise(im, 0.02),      # 6 high noise
    lambda im: augBlur(im, 0.8),        # 7 low blur
    lambda im: augBlur(im, 0.7),        # 8 high blur
]

# Transforms that fill() applies to the original image, before fix_image
trans_orig = [
    lambda im: rotate_by(im, 45),  # 0 rotate 45
    lambda im: random_scale(im),   # 1 rand scale
    lambda im: rotate_by(im, -45), # 2 rotate -45
    lambda im: random_scale(im)    # 3 rand scale
]

In [None]:
def fix_image(image):
    '''
    Centre-crop an image to a square and resize it to IM_SIZE x IM_SIZE.

    Odd dimensions are first rounded down to the nearest even number, so the
    last pixel column/row of an odd-sized image is dropped.

    Params
    image: PIL.Image.Image object

    Returns
        the cropped and resized image
    '''
    raw_w, raw_h = image.size
    even_w = raw_w - raw_w % 2
    even_h = raw_h - raw_h % 2

    # the square keeps the shorter (even) side and is centred on the longer one
    side = min(even_w, even_h)
    left = (even_w - side) // 2
    top = (even_h - side) // 2

    square = image.crop((left, top, left + side, top + side))
    return square.resize((IM_SIZE, IM_SIZE))

In [None]:
def copy(lim, root_dir, aug_dir):
  '''
    Copy up to `lim` images of every class from root_dir into aug_dir.

    Each image is opened with PIL and saved again under the same file name,
    in a sub-folder named after its class. The class names come from the
    module-level `class_dirs` list.

    Params
    lim: maximum number of images to copy per class
    root_dir: folder holding one sub-folder of source images per class
    aug_dir: destination folder (created, together with the class
             sub-folders, if it does not exist yet)
  '''
  # make sure the new base folder exists
  os.makedirs(aug_dir, exist_ok=True)

  for cdir in class_dirs:
    cpath = os.path.join(root_dir, cdir)
    wpath = os.path.join(aug_dir, cdir)
    os.makedirs(wpath, exist_ok=True)

    # only the first `lim` entries of the listing are copied
    for im_name in os.listdir(cpath)[:max(lim, 0)]:
      # the context manager closes the source file once the copy is written
      with Image.open(os.path.join(cpath, im_name)) as cim:
        cim.save(os.path.join(wpath, im_name))

In [None]:
def _next_free_path(stem, ext):
  '''Return stem + ext, or the first stem_<k> + ext (k = 1, 2, ...) that does not exist yet.'''
  candidate = stem + ext
  k = 0
  while os.path.exists(candidate):
    k += 1
    candidate = stem + '_' + str(k) + ext
  return candidate


def fill(lim, aug_dir):
  '''
    Fill every class folder of aug_dir up to `lim` images with augmented
    variants of the source images found under ROOT_DIR.

    The number of missing images is spread evenly over the class's source
    images (the first `missing % count` images get one extra variant). The
    transforms are taken in turn from `trans + trans_orig`:
      * transforms from `trans` are applied to the fix_image'd source image
      * transforms from `trans_orig` are applied to the source image first,
        and the result is then passed through fix_image
    Every variant is saved next to the copied originals as
    <name>__<transform id><ext>. If that name is already taken, a numeric
    suffix is added so nothing is overwritten and the target count is reached.

    Class names come from the module-level `class_dirs` list.

    Params
    lim: number of images each class folder should contain afterwards
    aug_dir: folder holding the augmented image sets (created if missing)
  '''
  # make sure the new base folder exists
  os.makedirs(aug_dir, exist_ok=True)

  all_transforms = trans + trans_orig
  trans_cnt = len(all_transforms)
  pre_cropped_cnt = len(trans)  # ids below this work on the fix_image'd image

  for cdir in class_dirs:
    cpath = os.path.join(ROOT_DIR, cdir)  # Source folder
    wpath = os.path.join(aug_dir, cdir)   # Augmented folder path
    os.makedirs(wpath, exist_ok=True)

    im_paths = os.listdir(cpath)  # all the image filenames
    cls_cnt = len(im_paths)       # how many images are in this class
    req = lim - len(os.listdir(wpath))  # how many new images we need

    # nothing to augment from, or the folder is already full
    if cls_cnt == 0 or req <= 0:
      continue

    # now evenly distribute the required number among all images
    trans_each, on_last = divmod(req, cls_cnt)
    at_trans = 0

    for i, im_name in enumerate(im_paths):
      # the first `on_last` images carry one extra transform each
      n_new = trans_each + (1 if i < on_last else 0)
      if n_new == 0:
        break  # the share never grows again, so no later image needs work

      stem, ext = os.path.splitext(os.path.join(wpath, im_name))

      with Image.open(os.path.join(cpath, im_name)) as cim:
        for _ in range(n_new):
          transform = all_transforms[at_trans]
          if at_trans < pre_cropped_cnt:
            new_image = transform(fix_image(cim))
          else:
            new_image = fix_image(transform(cim))

          # save transformed images with transform id suffix
          new_image.save(_next_free_path(stem + '__' + str(at_trans), ext))
          at_trans = (at_trans + 1) % trans_cnt

In [None]:
# Seed the first augmented directory with up to lim1 original images per class
copy(lim1, ROOT_DIR, AUG_DIR_1)

In [None]:
# Top every class folder in AUG_DIR_1 up to lim1 images with augmented variants
fill(lim1, AUG_DIR_1)

In [None]:
# Sanity check: print how many files each class folder of AUG_DIR_1 now holds
for cdir in os.listdir(AUG_DIR_1):
  print(cdir, ':', len(os.listdir(AUG_DIR_1 + '/' + cdir)))

Russula : 1000
Boletus : 1000
Agaricus : 1000
Amanita : 1000
Cortinarius : 1000
Lactarius : 1000
Hygrocybe : 1000
Entoloma : 1000
Suillus : 1000


In [None]:
# Build the second augmented set the same way, with lim2 as the per-class limit:
# copy the originals first, then fill each class folder with augmented variants
copy(lim2, ROOT_DIR, AUG_DIR_2)
fill(lim2, AUG_DIR_2)

In [None]:
# Sanity check: print how many files each class folder of AUG_DIR_2 now holds
for cdir in os.listdir(AUG_DIR_2):
  print(cdir, ':', len(os.listdir(AUG_DIR_2 + '/' + cdir)))

Suillus : 2000
Entoloma : 2000
Russula : 2000
Cortinarius : 2000
Boletus : 2000
Hygrocybe : 2000
Lactarius : 2000
Amanita : 2000
Agaricus : 2000


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=88b4a261-3cf5-4bb1-820f-4791ebb8a30d' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>