## Balance Datasets ##

This notebook takes the tiles from Image Preparation and augments the dataset by masking random portions of the images.

In [1]:
import os
import shutil
from PIL import Image
import matplotlib.pyplot as plt

import random
import numpy as np


In [3]:
#Create directory structure
#Everything is anchored on the absolute working directory so this cell does not
#depend on the kernel's current directory (later cells use the same absolute paths).
workdir='/kaggle/working/'
dirs=os.listdir(workdir)

cats=['ImagesTrain','ImagesVal']
labels=['CC', 'EC', 'HGSC', 'LGSC', 'MC']

for cat in cats:
    #remove any existing tree for this category before recreating it
    if cat in dirs:
        shutil.rmtree(workdir+cat)

    #one empty sub-folder per class label
    for label in labels:
        os.makedirs(workdir+cat+'/'+label, exist_ok=True)

In [4]:
#copy folders to working dir.
#Mirrors every <category>/<label>/ folder of the input dataset into the working tree,
#file by file, keeping file metadata (copy2).
for cat in cats:
    for label in labels:
        print(cat, label)
        subdir=cat+'/'+label+'/'
        source_dir='/kaggle/input/imageprep-07/'+subdir
        target_dir='/kaggle/working/'+subdir

        for name in os.listdir(source_dir):
            shutil.copy2(source_dir+name, target_dir+name)

ImagesTrain CC
ImagesTrain EC
ImagesTrain HGSC
ImagesTrain LGSC
ImagesTrain MC
ImagesVal CC
ImagesVal EC
ImagesVal HGSC
ImagesVal LGSC
ImagesVal MC


In [5]:
#Only if zipped
# !unzip -q '/kaggle/input/imageprep-07/_output_.zip'

## Augmentation ##
Images in the base dataset will be copied and modified until the target dataset sizes are met and the classes are balanced.  A few examples of augmentation include rotation, flipping images, and masking of vertical or horizontal regions of the image.  The size and location of each mask is chosen randomly.

If a class has more images than its target, randomly chosen images are removed until the target size is met.

Some examples are below.

| Horizontal Mask | Vertical Mask |
|      :----:    |      :----:   |
| ![](CC_000001_a-a-01.png)  |  ![](CC_000001_a-b-02.png) |

In [6]:
#Function to mask random vertical or horizontal bands in the images, based on the Operation
def augment(path, num, files, oper):
    """Save one flipped/rotated and masked copy of each of the first `num` images.

    Each processed image is opened from `path+file`, transformed, partly zeroed
    out, and written back into the same folder under a new name
    (`<text before the first '.'>-<oper>-<NN>.png`). The source file is untouched.

    Parameters
    ----------
    path : str
        Directory prefix. It is joined to file names by plain concatenation,
        so it must end with a path separator.
    num : int or float
        Maximum number of entries of `files` to process. Values <= 0 process nothing.
    files : list of str
        Image file names located under `path`.
    oper : str
        Augmentation type:
          'a' rotate 90,                 random row band            -> -a-01.png
          'b' rotate -90,                random column band         -> -b-02.png
          'c' flip left/right,           random row + column band   -> -c-03.png
          'd' flip top/bottom,           block [0:64, 0:64]         -> -d-04.png
          'e' flip left/right + rot 90,  block [128:256, 128:256]   -> -e-05.png
          'f' flip top/bottom + rot 90,  columns 64:192             -> -f-06.png

    Raises
    ------
    ValueError
        If `oper` is not one of 'a'..'f'.

    Notes
    -----
    Mask positions are hard-coded pixel values.
    NOTE(review): they look sized for 256x256 tiles - confirm against Image Preparation.
    """
    #oper -> numeric suffix used in the output file name
    suffixes={'a':'01', 'b':'02', 'c':'03', 'd':'04', 'e':'05', 'f':'06'}
    if oper not in suffixes:
        #previously an unknown oper reached img.save() with `name` undefined
        raise ValueError("oper must be one of 'a'..'f', got %r" % (oper,))

    for idx, file in enumerate(files):
        #>= rather than == so a negative or fractional num still stops the loop
        if idx>=num:
            break

        #context manager closes the source file handle once the copy is written
        with Image.open(path+file) as src:

            #masking - 25-50% somewhere in the middle .375-.625
            #in pixels = 64-128, centered 96-160
            #All four values are drawn for every image, in this order, whatever the
            #oper, so the seeded np.random stream is consumed identically.
            mwidth=np.random.randint(64, 129)
            mheight=np.random.randint(64, 129)
            mwidctr=np.random.randint(96, 161)
            mhtctr=np.random.randint(96, 161)
            rows=slice(mhtctr-mheight//2, mhtctr+mheight//2)
            cols=slice(mwidctr-mwidth//2, mwidctr+mwidth//2)

            if idx%500==0: print('Choice, ', oper)

            #geometric transform
            if oper=='a':
                img=src.rotate(90, expand=False)
            elif oper=='b':
                img=src.rotate(-90, expand=False)
            elif oper=='c':
                img=src.transpose(method=Image.Transpose.FLIP_LEFT_RIGHT)
            elif oper=='d':
                img=src.transpose(method=Image.Transpose.FLIP_TOP_BOTTOM)
            elif oper=='e':
                img=src.transpose(method=Image.Transpose.FLIP_LEFT_RIGHT)
                img=img.rotate(90, expand=False)
            else:
                img=src.transpose(method=Image.Transpose.FLIP_TOP_BOTTOM)
                img=img.rotate(90, expand=False)

            #masking - the channel axis is left implicit so 2-D (single channel)
            #arrays are handled as well as H x W x C ones
            imgarray=np.array(img)
            if oper=='a':
                #Row
                imgarray[rows] = 0
            elif oper=='b':
                #Column
                imgarray[:, cols] = 0
            elif oper=='c':
                #Row and Col
                imgarray[rows] = 0
                imgarray[:, cols] = 0
            elif oper=='d':
                #Upper left block
                imgarray[0:64, 0:64] = 0
            elif oper=='e':
                #Lower right quadrant; the end index is exclusive, so 256 (not 255)
                #is needed to include the last row and column
                imgarray[128:256, 128:256] = 0
            else:
                #middle fat column
                imgarray[:, 64:192] = 0

            name=file.split('.')[0] +'-'+ str(oper)+'-'+suffixes[oper]+'.png'
            Image.fromarray(np.uint8(imgarray)).save(path+name)

In [7]:
#Function that removes image tiles if count exceeds target
def prune(path, num, files):
    """Delete `num` randomly chosen files and drop them from `files`.

    Parameters
    ----------
    path : str
        Directory prefix. It is joined to file names by plain concatenation,
        so it must end with a path separator.
    num : int or float
        Number of files to delete. It is rounded to the nearest integer because
        callers may compute it from a float target; values <= 0 delete nothing,
        and values larger than len(files) delete every listed file.
    files : list of str
        Candidate file names under `path`. The list is modified in place: every
        deleted name is removed from it.
    """
    #range()/sample() need an int; clamp so an oversized request cannot raise
    count=min(int(round(num)), len(files))
    if count<=0:
        return

    #sample() picks distinct entries, so no file is chosen twice
    for file in random.sample(files, count):
        files.remove(file)
        os.remove(path+file)

In [1]:
#Loop that balances classes and augments image sets until total image target is met

#total images
total=80000

#full, thumb
#Rounded to ints: these counts end up in range() inside prune and in the
#index comparison inside augment, neither of which should see a float.
tartrain=[round(.128*total), round(.032*total)]
tarval=[round(.032*total), round(.008*total)]

#repeatable masking (augment draws from np.random) and
#repeatable pruning (prune draws from the random module)
np.random.seed(41)
random.seed(41)

#augmentation types - flip/rotation, location of masking
opers=['a','b','c','d','e','f']


def _is_full(name):
    """Return True for full-size tiles (more than one '_' in the name); the rest are thumbs."""
    return name.count('_')>1


def _balance(path, base, target, want_full, verbose=False):
    """Bring one image kind in one class folder to `target` files.

    path      -- class folder, ending with '/'.
    base      -- names of the original images of this kind found in `path`.
    target    -- number of images of this kind wanted.
    want_full -- True to count full-size tiles, False to count thumbs.
    verbose   -- print the remaining shortfall before each augmentation pass.

    If `base` already holds at least `target` files the surplus is pruned;
    otherwise augmentation passes 'a'..'f' run until the folder holds enough.
    """
    if target<=len(base):
        prune(path, len(base)-target, base)
        return

    for oper in opers:
        #recount on disk: earlier passes have already added files
        have=sum(1 for name in os.listdir(path) if _is_full(name)==want_full)
        if have>=target:
            break

        num=target-have
        if verbose:
            print('Re-augment',oper,num)
        #augment on original list, so augmented copies are never augmented again
        augment(path, num, base, oper)


for cat in cats:
    #set targets
    targets=tarval if 'Val' in cat else tartrain

    for label in labels:
        path='/kaggle/working/'+cat+'/'+label+'/'
        #os.listdir returns names in arbitrary order; sorting fixes which files
        #are augmented and which random mask each one receives under the seed
        files=sorted(os.listdir(path))

        #get base list of file types
        full=[name for name in files if _is_full(name)]
        thumb=[name for name in files if not _is_full(name)]

        #thumbs
        _balance(path, thumb, targets[1], want_full=False)

        #fullsize
        _balance(path, full, targets[0], want_full=True, verbose=True)

In [11]:
#Verify target sample counts are met
#Prints, per category, the number of files (full-size and thumb together)
#found in each class folder of the working tree.
for split in cats:
    print(split)
    for cls in labels:
        class_dir='/kaggle/working/'+split+'/'+cls+'/'
        print(cls, len(os.listdir(class_dir)))

ImagesTrain
CC 12800
EC 12800
HGSC 12800
LGSC 12800
MC 12800
ImagesVal
CC 3200
EC 3200
HGSC 3200
LGSC 3200
MC 3200
