In [1]:
import csv
import glob
import os
import random
import shutil
import subprocess

from PIL import Image, ImageOps
from pyunpack import Archive

In [2]:
# Location of the downloaded archive. Set the AOI_ZIP_PATH environment
# variable to point elsewhere instead of editing this hardcoded default.
AOI_ZIP_PATH = os.environ.get('AOI_ZIP_PATH', '/root/aoi.zip')
EXTRACT_PATH = './data/raw_data'

def unzip(zip_path=None, extract_path=None):
    """Extract the AOI archive into a fresh directory, then unpack inner zips.

    The destination is deleted (if it is a directory) and recreated, so every
    call starts from an empty tree. After the main archive is extracted, each
    ``*.zip`` found directly inside the destination is extracted in place and
    then deleted.

    Args:
        zip_path: archive to extract; defaults to ``AOI_ZIP_PATH``.
        extract_path: destination directory; defaults to ``EXTRACT_PATH``.

    Raises:
        OSError: if the destination cannot be removed or created. The previous
            bare ``except: pass`` hid such errors until extraction failed later.
    """
    zip_path = AOI_ZIP_PATH if zip_path is None else zip_path
    extract_path = EXTRACT_PATH if extract_path is None else extract_path

    # Start from an empty destination so stale files from an earlier run
    # cannot leak into the dataset.
    if os.path.isdir(extract_path):
        shutil.rmtree(extract_path)
    os.makedirs(extract_path)

    Archive(zip_path).extractall(extract_path)
    # The listing is taken once, so only one level of nesting is handled:
    # zips produced by extracting these inner zips are left as-is.
    for inner_zip in glob.glob(os.path.join(glob.escape(extract_path), '*.zip')):
        Archive(inner_zip).extractall(extract_path)
        os.remove(inner_zip)

In [3]:
RAW_DATA_PATH = './data/raw_data/train_images'
DATA_PATH = './data/sep_data'
TRAIN_CSV_PATH = './data/raw_data/train.csv'
CLASS = list(map(str, range(6)))
split_rate = 0.2  # probability that a CSV row is routed to the validation split

def split(raw_data_path=None, data_path=None, train_csv_path=None,
          validation_rate=None, seed=None):
    """Randomly copy labelled images into per-class train/validation folders.

    ``data_path`` is wiped and recreated as
    ``<data_path>/{train,validation}/<class>`` for every label in ``CLASS``.
    Each CSV row after the header is sent to ``validation`` with probability
    ``validation_rate`` and to ``train`` otherwise. The image named in the
    first column is copied into the folder of the label in the second column.

    Args:
        raw_data_path: folder holding the source images; defaults to
            ``RAW_DATA_PATH``.
        data_path: output root, deleted first if present; defaults to
            ``DATA_PATH``.
        train_csv_path: CSV with a header row followed by ``ID,label`` rows;
            defaults to ``TRAIN_CSV_PATH``.
        validation_rate: fraction of rows routed to validation; defaults to
            ``split_rate``.
        seed: if given, a private ``random.Random(seed)`` makes the split
            reproducible. Otherwise the global ``random`` module is used
            (original behaviour, which honours any earlier ``random.seed``).

    Raises:
        ValueError: for a non-blank row that does not have exactly two
            fields, or whose label is not in ``CLASS``.
    """
    raw_data_path = RAW_DATA_PATH if raw_data_path is None else raw_data_path
    data_path = DATA_PATH if data_path is None else data_path
    train_csv_path = TRAIN_CSV_PATH if train_csv_path is None else train_csv_path
    validation_rate = split_rate if validation_rate is None else validation_rate
    rng = random if seed is None else random.Random(seed)

    # Wipe once, up front. The old per-directory FileExistsError handler could
    # delete folders it had just created when only part of the tree existed,
    # which made later copies fail.
    if os.path.isdir(data_path):
        shutil.rmtree(data_path)
    for tv in ('train', 'validation'):
        for idx in CLASS:
            os.makedirs(os.path.join(data_path, tv, idx))

    with open(train_csv_path, 'r', newline='') as fh:
        reader = csv.reader(fh)
        next(reader, None)  # skip the header row
        for row in reader:
            fields = [field.strip() for field in row]
            if not any(fields):
                continue  # blank line, e.g. a trailing newline
            if len(fields) != 2:
                raise ValueError('{}:{}: expected "ID,label", got {!r}'.format(
                    train_csv_path, reader.line_num, row))
            image_id, label = fields
            if label not in CLASS:
                raise ValueError('{}:{}: unknown label {!r}'.format(
                    train_csv_path, reader.line_num, label))
            tv = 'validation' if rng.random() < validation_rate else 'train'
            shutil.copy(os.path.join(raw_data_path, image_id),
                        os.path.join(data_path, tv, label, image_id))

In [4]:
def _count_files(directory):
    """Return the number of non-hidden files under ``directory``, recursively.

    Dot-prefixed files and folders are skipped, matching what ``tree`` lists
    without ``-a``. A missing directory yields 0.
    """
    total = 0
    for _root, dirs, files in os.walk(directory):
        dirs[:] = [d for d in dirs if not d.startswith('.')]  # prune hidden dirs
        total += sum(1 for f in files if not f.startswith('.'))
    return total


def count(data_path=None, classes=None):
    """Print the number of files in each class folder of both splits.

    The output format is unchanged: a header line, then one ``'{:>8}: {}'``
    line per class under each of ``train`` and ``validation``.

    Files are counted in Python rather than by running the external ``tree``
    binary and picking a word out of its summary line. That approach needed
    ``tree`` installed and depended on its exact output format.

    Args:
        data_path: root holding ``train/`` and ``validation/``; defaults to
            ``DATA_PATH``.
        classes: class folder names to report; defaults to ``CLASS``.
    """
    data_path = DATA_PATH if data_path is None else data_path
    classes = CLASS if classes is None else classes
    print("Image counts:")
    for tv in ['train', 'validation']:
        print('   {}:'.format(tv))
        for idx in classes:
            n_files = _count_files(os.path.join(data_path, tv, idx))
            print('{:>8}: {}'.format(idx, n_files))

In [5]:
def _augmentation_plan(idx):
    """Return ``(filename_suffix, transform)`` pairs applied to each image of class ``idx``.

    One "round" is a vertical flip, a horizontal mirror and one rotation per
    angle. Classes '2' and '3' are only rotated by 180 degrees; the others by
    90, 180 and 270. Classes '2'-'4' get a second round and class '2' gets a
    third and fourth. Later rounds apply the same deterministic transforms, so
    they are exact duplicates under new names, i.e. plain oversampling.
    """
    angles = [180] if idx in ('2', '3') else [90, 180, 270]

    def one_round(flip_tag, angle_tag):
        steps = [('_flip_v' + flip_tag, ImageOps.flip),
                 ('_flip_h' + flip_tag, ImageOps.mirror)]
        for angle in angles:
            # Bind angle now; a bare closure would see only the last value.
            steps.append(('_' + str(angle) + angle_tag,
                          lambda im, a=angle: im.rotate(a)))
        return steps

    plan = one_round('', '')
    if idx in ('2', '3', '4'):
        plan += one_round('_2', '_')
    if idx == '2':
        plan += one_round('_3', '_2') + one_round('_4', '_3')
    return plan


def augmentation(data_path=None, classes=None):
    """Oversample the training split by saving flipped/rotated PNG copies.

    For every ``*.png`` in ``<data_path>/train/<class>``, each transform from
    ``_augmentation_plan(class)`` is applied and written next to the source as
    ``<stem><suffix>.png``. Per source image this adds 5 copies for classes
    0, 1 and 5, 6 for class 3, 10 for class 4 and 12 for class 2. Progress is
    printed as ``<class> <index>`` every 300 images.

    Each copy is saved as soon as it is made, and every source file is closed
    after use. The previous version kept a whole class's augmented images in
    memory and never closed the opened files.

    NOTE(review): ``Image.rotate`` keeps the original canvas size, so 90/270
    degree rotations of non-square images are cropped. Presumably the images
    are square; TODO confirm.

    Args:
        data_path: root holding ``train/<class>`` folders; defaults to
            ``DATA_PATH``.
        classes: class folder names to augment; defaults to ``CLASS``.
    """
    data_path = DATA_PATH if data_path is None else data_path
    classes = CLASS if classes is None else classes
    tv = 'train'
    for idx in classes:
        class_dir = os.path.join(data_path, tv, idx)
        # Snapshot the listing before writing so new copies are not re-augmented.
        img_paths = glob.glob(os.path.join(glob.escape(class_dir), '*.png'))
        plan = _augmentation_plan(idx)
        for index, img_path in enumerate(img_paths):
            # splitext keeps multi-dot names intact ('a.b.png' -> 'a.b');
            # the old split('.')[-2] produced 'b'.
            stem = os.path.splitext(os.path.basename(img_path))[0]
            with Image.open(img_path) as img:
                for suffix, transform in plan:
                    transform(img).save(os.path.join(class_dir, stem + suffix + '.png'))
            if index % 300 == 0:
                print(idx, index)

        
# Seed the global RNG that split() draws from, so the train/validation
# assignment is reproducible after a kernel restart. The augmentation step
# uses only fixed flips/rotations and needs no seed.
SPLIT_SEED = 42
random.seed(SPLIT_SEED)
split()
augmentation()

0 0
0 300
1 0
1 300
2 0
3 0
4 0
5 0
5 300


In [6]:
# Report per-class file counts for both splits; the train counts include the
# augmented copies written by augmentation() in the previous cell.
count()

Image counts:
   train:
       0: 3168
       1: 2370
       2: 1053
       3: 2023
       4: 2200
       5: 3042
   validation:
       0: 146
       1: 97
       2: 19
       3: 89
       4: 40
       5: 137
