In [23]:
import os
import random
import glob
import shutil
import subprocess

from PIL import Image, ImageOps
from pyunpack import Archive

In [42]:
# Archive that unzip() extracts, and the directory it is extracted into.
# unzip() wipes and re-creates EXTRACT_PATH when it already exists.
AOI_ZIP_PATH = '/root/aoi.zip'
EXTRACT_PATH = '/root/aoi'

def unzip():
    """Extract AOI_ZIP_PATH into a fresh EXTRACT_PATH, then expand inner zips.

    An existing EXTRACT_PATH directory is removed first so every run starts
    from an empty directory. After the outer archive is extracted, each
    ``*.zip`` found directly inside EXTRACT_PATH is extracted in place and
    then deleted.

    Errors from creating the directory or from extraction are propagated
    instead of being silently ignored, so a failed setup is visible right
    away rather than surfacing later as missing files.
    """
    if os.path.isdir(EXTRACT_PATH):
        shutil.rmtree(EXTRACT_PATH)
    os.makedirs(EXTRACT_PATH)

    Archive(AOI_ZIP_PATH).extractall(EXTRACT_PATH)

    # Only archives sitting directly in EXTRACT_PATH are expanded (one level).
    for path in glob.glob(os.path.join(EXTRACT_PATH, '*.zip')):
        Archive(path).extractall(EXTRACT_PATH)
        os.remove(path)

In [43]:
# Unpack the dataset archive; this wipes any previous contents of EXTRACT_PATH.
unzip()

In [44]:
# Directory the source images are copied from (file names come from the CSV).
RAW_DATA_PATH = '/root/aoi/train_images'
# Root of the generated <train|validation>/<class>/ folder tree.
DATA_PATH = './data/sep_data'
# CSV with a header row followed by "ID,label" rows.
TRAIN_CSV_PATH = '/root/aoi/train.csv'
# Class labels as strings '0'..'5'; also used as the class folder names.
CLASS = list(map(str, range(6)))
# A row goes to 'validation' when random.random() < split_rate, else to 'train'.
split_rate = 0.2

def split(seed=None):
    """Copy the labelled images into a fresh train/validation folder tree.

    The tree ``DATA_PATH/<train|validation>/<class>/`` is created for every
    class in CLASS. If any of those folders already exists, DATA_PATH is
    removed once, up front, and the whole tree is rebuilt; removing it in the
    middle of the creation loop would delete folders created earlier in the
    same call.

    Each data row of TRAIN_CSV_PATH (``ID,label`` after one header line) is
    assigned to 'validation' with probability ``split_rate`` and to 'train'
    otherwise, and the file ``RAW_DATA_PATH/ID`` is copied into the matching
    class folder. Blank lines in the CSV are skipped.

    Args:
        seed: Optional seed for a private random generator, making the split
            reproducible. With the default ``None`` the module-level
            ``random`` generator is used, exactly as before.
    """
    target_dirs = [os.path.join(DATA_PATH, tv, idx)
                   for tv in ('train', 'validation')
                   for idx in CLASS]
    if any(os.path.exists(target) for target in target_dirs):
        shutil.rmtree(DATA_PATH)
    for target in target_dirs:
        os.makedirs(target)

    rng = random if seed is None else random.Random(seed)

    with open(TRAIN_CSV_PATH, 'r', newline='') as csv_file:
        csv_file.readline()  # skip the header row
        for line in csv_file:
            line = line.strip()
            if not line:
                continue  # tolerate blank / trailing empty lines
            ID, label = line.split(',')
            tv = 'train' if rng.random() >= split_rate else 'validation'
            shutil.copy(os.path.join(RAW_DATA_PATH, ID),
                        os.path.join(DATA_PATH, tv, label, ID))

In [47]:
# Build the train/validation folders. The assignment is random, so the
# per-class counts change from run to run.
split()

In [48]:
def count():
    """Print how many files each class folder holds, per split.

    For every split ('train', 'validation') and every class in CLASS, the
    files under ``DATA_PATH/<split>/<class>`` are counted recursively and
    printed as an indented table. Entries whose name starts with a dot (for
    example ``.ipynb_checkpoints``) are ignored. A missing folder counts as 0.

    The count is done in Python with ``os.walk`` so the function does not
    depend on an external ``tree`` executable or on the wording of its
    summary line.
    """
    print("Image counts:")
    for tv in ['train', 'validation']:
        print('   {}:'.format(tv))
        for idx in CLASS:
            class_dir = os.path.abspath(os.path.join(DATA_PATH, tv, idx))
            cnt = 0
            for _, dirnames, filenames in os.walk(class_dir):
                # Prune hidden directories in place so os.walk skips them.
                dirnames[:] = [d for d in dirnames if not d.startswith('.')]
                cnt += sum(1 for name in filenames if not name.startswith('.'))
            print('{:>8}: {}'.format(idx, cnt))

In [49]:
# Per-class image counts of the split, before any augmentation.
count()

Image counts:
   train:
       0: 543
       1: 411
       2: 83
       3: 301
       4: 183
       5: 507
   validation:
       0: 131
       1: 81
       2: 17
       3: 77
       4: 57
       5: 137


In [50]:
def augmentation():
    """Write flipped and rotated copies of every training image.

    For each class folder ``DATA_PATH/train/<class>`` every ``*.png`` present
    when the class is processed gets these extra files next to it:

    * ``<name>_flip_v_<i>.png``  - vertical flip (``ImageOps.flip``)
    * ``<name>_flip_h_<i>.png``  - horizontal mirror (``ImageOps.mirror``)
    * ``<name>_<angle>_<i>.png`` - rotation by ``angle`` degrees; only 180
      for classes '2' and '3', otherwise 90, 180 and 270

    ``i`` is 0 for all classes except '2', where every variant is written
    four times (``i`` = 1..4) to oversample that class.

    Variants are saved as soon as they are produced, so only one source
    image and its variants are in memory at a time, and each source file is
    closed after it is read.

    NOTE: the function is not idempotent. A second run on the same folders
    also picks up the files written by the first run, so run it on a fresh
    split.
    """
    tv = 'train'
    for idx in CLASS:
        class_dir = os.path.join(DATA_PATH, tv, idx)
        # Snapshot the sources before writing so files created below are
        # never treated as inputs during this run.
        img_paths = glob.glob(os.path.join(class_dir, '*.png'))
        angles = [180] if idx in ('2', '3') else [90, 180, 270]
        copies = [1, 2, 3, 4] if idx == '2' else [0]

        for img_path in img_paths:
            # splitext keeps dots inside the stem ("a.b.png" -> "a.b").
            img_name = os.path.splitext(os.path.basename(img_path))[0]
            with Image.open(img_path) as img:
                variants = [('_flip_v', ImageOps.flip(img)),
                            ('_flip_h', ImageOps.mirror(img))]
                for angle in angles:
                    variants.append(('_{}'.format(angle), img.rotate(angle)))
            for tag, variant in variants:
                for i in copies:
                    file_name = '{}{}_{}.png'.format(img_name, tag, i)
                    variant.save(os.path.join(class_dir, file_name))

        
# Re-create a clean split before augmenting: augmentation() globs every *.png
# in the train folders, so running it on an already-augmented tree would
# augment the augmented copies as well. Because split() is random, the counts
# after this cell come from a different split than the one counted earlier.
split()
augmentation()

In [51]:
# Per-class image counts after the re-split and augmentation of the train set.
count()

Image counts:
   train:
       0: 3222
       1: 2346
       2: 1001
       3: 1220
       4: 1140
       5: 3090
   validation:
       0: 137
       1: 101
       2: 23
       3: 73
       4: 50
       5: 129
