In [41]:
import os
import random
import glob
import shutil
import subprocess

from tqdm import tqdm
from PIL import Image, ImageOps
from pyunpack import Archive

In [42]:
AOI_ZIP_PATH = '/root/aoi.zip'
EXTRACT_PATH = '/root/aoi'

def unzip():
    try:
        os.makedirs(EXTRACT_PATH)
    except FileExistsError:
        shutil.rmtree(EXTRACT_PATH)
        os.mkdir(EXTRACT_PATH)
    except:
        pass
    Archive(AOI_ZIP_PATH).extractall(EXTRACT_PATH)
    for path in glob.glob(''.join([EXTRACT_PATH, '/*.zip'])):
        Archive(path).extractall(EXTRACT_PATH)
        os.remove(path)

In [43]:
unzip()

In [57]:
RAW_DATA_PATH = '/root/aoi/train_images'
DATA_PATH = './data/sep_data'
DATA_PATH_2 = './data/sep2_data'
TRAIN_CSV_PATH = '/root/aoi/train.csv'
CLASS = list(map(str, range(6)))
split_rate = 0.2

def split():
    for tv in ['train', 'validation']:
        for idx in CLASS:
            try:
                os.makedirs('/'.join([DATA_PATH, tv, idx]))
            except FileExistsError:
                shutil.rmtree(DATA_PATH)
                os.makedirs('/'.join([DATA_PATH, tv, idx]))
        for idx in ['0', '1']:
            try:
                os.makedirs('/'.join([DATA_PATH_2, tv, idx]))
            except FileExistsError:
                shutil.rmtree(DATA_PATH_2)
                os.makedirs('/'.join([DATA_PATH_2, tv, idx]))
    with open(TRAIN_CSV_PATH, 'r', newline='') as csv:
        csv.readline()
        for line in csv.readlines():
            ID, label = line.strip().split(',')
            if random.random() >= split_rate:
                tv = 'train'
            else:
                tv = 'validation'
            shutil.copy('/'.join([RAW_DATA_PATH, ID]), 
                        '/'.join([DATA_PATH, tv, label, ID]))           

In [58]:
split()

In [59]:
def count():
    print("Image counts:")
    for tv in ['train', 'validation']:
        print('   {}:'.format(tv))
        for idx in CLASS:
            abs_path = os.path.abspath('/'.join([DATA_PATH, tv, idx]))
            cnt = (bytearray(subprocess.run(['tree', abs_path], 
                                           stdout=subprocess.PIPE).stdout).decode().split(' ')[-2])
            print('{:>8}: {}'.format(idx, cnt))
        print('   {}:'.format(tv))
        for idx in ['0', '1']:
            abs_path = os.path.abspath('/'.join([DATA_PATH_2, tv, idx]))
            cnt = (bytearray(subprocess.run(['tree', abs_path], 
                                           stdout=subprocess.PIPE).stdout).decode().split(' ')[-2])
            print('{:>8}: {}'.format(idx, cnt))  

In [60]:
count()

Image counts:
   train:
       0: 540
       1: 391
       2: 83
       3: 312
       4: 194
       5: 507
   train:
       0: 0
       1: 0
   validation:
       0: 134
       1: 101
       2: 17
       3: 66
       4: 46
       5: 137
   validation:
       0: 0
       1: 0


In [62]:
def augmentation():
    tv = 'train'
    for idx in CLASS:
        img_paths = glob.glob('/'.join([DATA_PATH, tv, idx, '*.png']))
        imgs = []
        des_paths = []
        def f(i, des_paths):
            for index in tqdm(range(len(img_paths))):
                img_path = img_paths[index]
                img = Image.open(img_path)
                img = img.resize((128, 128))
                img_name = img_path.split('/')[-1].split('.')[-2]
                imgs.append(img)
                des_path = ''.join([img_name, '{}', '.png'])
                des_paths.append(des_path.format(''))
                imgs.append(ImageOps.flip(img))
                des_paths.append(des_path.format('_flip_v_{}'.format(str(i))))
                imgs.append(ImageOps.mirror(img))
                des_paths.append(des_path.format('_flip_h_{}'.format(str(i))))
                angles = [180] if idx in ['2', '3'] else [90, 180, 270]
                for angle in angles:
                    imgs.append(img.rotate(angle))
                    des_paths.append(des_path.format('_{}_{}'.format(str(angle), i)))
            return des_paths
        
        if idx == '2':
            for i in [1, 2, 3, 4]:
                des_paths = f(i, des_paths)
        elif idx == '4':
            for i in [1, 2]:
                des_paths = f(i, des_paths)      
        else:
            des_paths = f(0, des_paths)
        for i in tqdm(range(len(imgs))):
            img = imgs[i]
            des_path = des_paths[i]
            if idx == '0':
                des_path = '/'.join([DATA_PATH_2, tv, idx, des_path]) 
            else:
                des_path = '/'.join([DATA_PATH_2, tv, '1', des_path]) 
            img.save(des_path)
    tv = 'validation'
    for idx in CLASS:
        img_paths = glob.glob('/'.join([DATA_PATH, tv, idx, '*.png'])) 
        imgs = []
        des_paths = []
        def f(i, des_paths):
            for index in tqdm(range(len(img_paths))):
                img_path = img_paths[index]
                img = Image.open(img_path)
                img = img.resize((128, 128))
                img_name = img_path.split('/')[-1].split('.')[-2]
                imgs.append(img)
                des_path = ''.join([img_name, '{}', '.png'])
                des_paths.append(des_path.format(''))
            return des_paths
        
        des_paths = f(0, des_paths)
        for i in tqdm(range(len(imgs))):
            img = imgs[i]
            des_path = des_paths[i]
            if idx == '0':
                des_path = '/'.join([DATA_PATH_2, tv, idx, des_path]) 
            else:
                des_path = '/'.join([DATA_PATH_2, tv, '1', des_path]) 
            img.save(des_path)
split()
augmentation()

100%|██████████| 550/550 [00:01<00:00, 316.12it/s]
100%|██████████| 3300/3300 [00:02<00:00, 1312.56it/s]
100%|██████████| 401/401 [00:01<00:00, 322.04it/s]
100%|██████████| 2406/2406 [00:01<00:00, 1344.58it/s]
100%|██████████| 81/81 [00:00<00:00, 313.58it/s]
100%|██████████| 81/81 [00:00<00:00, 317.87it/s]
100%|██████████| 81/81 [00:00<00:00, 319.11it/s]
100%|██████████| 81/81 [00:00<00:00, 319.66it/s]
100%|██████████| 1296/1296 [00:01<00:00, 1278.29it/s]
100%|██████████| 293/293 [00:00<00:00, 319.36it/s]
100%|██████████| 1172/1172 [00:00<00:00, 1428.37it/s]
100%|██████████| 190/190 [00:00<00:00, 322.00it/s]
100%|██████████| 190/190 [00:00<00:00, 321.07it/s]
100%|██████████| 2280/2280 [00:01<00:00, 1396.08it/s]
100%|██████████| 518/518 [00:01<00:00, 325.06it/s]
100%|██████████| 3108/3108 [00:03<00:00, 872.11it/s]
100%|██████████| 124/124 [00:00<00:00, 315.94it/s]
100%|██████████| 124/124 [00:00<00:00, 1360.07it/s]
100%|██████████| 91/91 [00:00<00:00, 332.85it/s]
100%|██████████| 91/91 

In [63]:
count()

Image counts:
   train:
       0: 550
       1: 401
       2: 81
       3: 293
       4: 190
       5: 518
   train:
       0: 3300
       1: 9829
   validation:
       0: 124
       1: 91
       2: 19
       3: 85
       4: 50
       5: 126
   validation:
       0: 124
       1: 371
