# 数据整理

In [1]:
import os
import shutil
import random
import numpy as np
import json
import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
def clear_dir(path: str) -> None:
    """Reset ``path`` to an empty directory.

    An existing directory tree at ``path`` is deleted and the directory is
    created again. Missing parent directories are created too, so a nested
    target such as ``train/cat`` works even when ``train`` does not exist.

    Args:
        path: Directory to (re)create.
    """
    if os.path.exists(path):
        shutil.rmtree(path)

    # makedirs (not mkdir) so a missing parent directory does not raise.
    os.makedirs(path)
    

In [3]:
# Names of the images flagged as outliers; stays empty when the JSON file
# has not been produced yet.
outliers = []
outliers_file = 'outliers_total.json'
if os.path.exists(outliers_file):
    with open(outliers_file, 'r') as fp:
        outliers = json.load(fp)
    print(len(outliers))

116


### 创建原始数据链接
* 将数据链接到'train\cat'，'train\dog'，'test\test'文件夹，用于flow_from_directory；
* 去掉outliers；
* 从'crop_data'目录链接裁剪数据。

In [5]:
def ref_test_data():
    """Expose the raw test images as ``test/test`` through a directory symlink.

    The ``test`` directory is recreated from scratch with ``clear_dir`` and a
    single symlink ``test/test`` is created inside it. The link target is
    the current working directory joined with ``../dog_vs_cat_data/test1``.
    """
    source = '../dog_vs_cat_data/test1'
    target_root = 'test'

    working_dir = os.getcwd()
    clear_dir(target_root)

    link_target = os.sep.join((working_dir, source))
    link_name = os.sep.join((target_root, 'test'))
    os.symlink(link_target, link_name, target_is_directory=True)


In [50]:
def ref_train_data(class_name: str, outlier_list: list):
    """Populate ``train/<class_name>`` with symlinks to the usable images.

    Two sources are linked:

    * every file in ``../dog_vs_cat_data/train`` whose name starts with
      ``class_name`` and is not listed in ``outlier_list``;
    * every file in ``crop_data`` (if that directory exists) whose name
      starts with ``class_name``. A crop whose name is already taken in the
      destination is skipped with a printed notice.

    Args:
        class_name: File-name prefix selecting the class (the notebook calls
            this with ``'cat'`` and ``'dog'``); also the sub-directory name.
        outlier_list: File names to leave out of the raw-data links.
    """
    src_dir = '../dog_vs_cat_data/train'
    src_crop_dir = 'crop_data'
    cur_path = os.getcwd()

    # Only the per-class sub-directory is reset; 'train' itself is kept so
    # links already created for another class survive.
    os.makedirs('train', exist_ok=True)
    des_dir = os.path.join('train', class_name)
    clear_dir(des_dir)

    # Build the set once: membership is tested for every source file.
    excluded = set(outlier_list)

    # Link raw data.
    for file_name in os.listdir(src_dir):
        if file_name in excluded or not file_name.startswith(class_name):
            continue
        os.symlink(os.path.join(cur_path, src_dir, file_name),
                   os.path.join(des_dir, file_name))

    # Link cropped data; the crop directory is optional.
    if not os.path.exists(src_crop_dir):
        return

    for file_name in os.listdir(src_crop_dir):
        if not file_name.startswith(class_name):
            continue
        des_file = os.path.join(des_dir, file_name)
        # lexists also reports dangling symlinks, which os.symlink would
        # refuse to overwrite.
        if os.path.lexists(des_file):
            print("crop file exists!")
        else:
            os.symlink(os.path.join(cur_path, src_crop_dir, file_name), des_file)


In [51]:
# Build the per-class training links first, then the test link.
for class_name in ('cat', 'dog'):
    ref_train_data(class_name, outliers)
ref_test_data()


### 分割数据
* 随机抽取20%数据至"val"文件夹
* 剩余80%训练数据按1:9分为"pretrain"、"finetune"两部分

In [4]:
def random_sample_data(total_num: int, val_split=0.2, pretrain_split=0.1, seed=2018):
    """Randomly partition ``range(total_num)`` into three index lists.

    ``round(total_num * val_split)`` indices go to validation. Of the
    remaining indices, ``round(pretrain_split * remaining)`` go to pretrain
    and the rest to finetune.

    The shuffle uses a private ``random.Random(seed)`` instance, so the
    result is the same as seeding the module-level generator with ``seed``,
    but the global random state is left untouched.

    Args:
        total_num: Number of items to split.
        val_split: Fraction of all items assigned to validation.
        pretrain_split: Fraction of the non-validation items assigned to
            pretrain.
        seed: Seed for the shuffle.

    Returns:
        A tuple ``(pretrain_indices, finetune_indices, val_indices)`` of
        lists that together contain each index exactly once.

    Raises:
        ValueError: If either split fraction is outside ``[0, 1]``.
    """
    if not 0 <= val_split <= 1 or not 0 <= pretrain_split <= 1:
        # Out-of-range fractions would yield negative counts and therefore
        # overlapping slices below.
        raise ValueError('val_split and pretrain_split must be within [0, 1]')

    val_num = round(total_num * val_split)
    train_num = total_num - val_num
    pretrain_num = round(pretrain_split * train_num)

    index_list = list(range(total_num))
    random.Random(seed).shuffle(index_list)

    return (
        index_list[:pretrain_num],
        index_list[pretrain_num:train_num],
        index_list[train_num:],
    )


In [5]:
def split_train_data(class_name: str, val_split=0.2, pretrain_split=0.1):
    """Distribute the links in ``train/<class_name>`` over three split folders.

    The entries of ``train/<class_name>`` are partitioned with
    ``random_sample_data`` and re-linked into ``pretrain/<class_name>``,
    ``finetune/<class_name>`` and ``val/<class_name>``. Each of those
    directories is recreated from scratch. Every source entry must itself
    be a symlink: the new link points at the same target, read with
    ``os.readlink``.

    Args:
        class_name: Class sub-directory to split.
        val_split: Fraction of the files placed in ``val``.
        pretrain_split: Fraction of the remaining files placed in
            ``pretrain``; the rest goes to ``finetune``.
    """
    src_dir = os.path.join('train', class_name)

    # os.listdir returns entries in arbitrary order; sort so that the seeded
    # shuffle assigns the same files to the same split on every run.
    file_list = sorted(os.listdir(src_dir))

    pretrain_indices, finetune_indices, _ = random_sample_data(
        len(file_list), val_split=val_split, pretrain_split=pretrain_split)

    # Recreate '<split>/<class_name>' for every split; the split root is
    # kept so another class's sub-directory is not wiped.
    split_dirs = {}
    for split_name in ('pretrain', 'finetune', 'val'):
        os.makedirs(split_name, exist_ok=True)
        split_dir = os.path.join(split_name, class_name)
        clear_dir(split_dir)
        split_dirs[split_name] = split_dir

    # Sets give O(1) membership tests inside the per-file loop.
    pretrain_set = set(pretrain_indices)
    finetune_set = set(finetune_indices)

    for i, file_name in enumerate(file_list):
        if i in pretrain_set:
            target_dir = split_dirs['pretrain']
        elif i in finetune_set:
            target_dir = split_dirs['finetune']
        else:
            target_dir = split_dirs['val']

        # Point at the original file rather than chaining through 'train'.
        os.symlink(os.readlink(os.path.join(src_dir, file_name)),
                   os.path.join(target_dir, file_name))


In [6]:
# Split each class with the default 20% validation / 10% pretrain ratios.
for class_name in ('cat', 'dog'):
    split_train_data(class_name)
        

### Verify:

In [7]:
def verify_data(class_name: str, outliers: list):
    """Check the contents of ``train/<class_name>`` and print the file count.

    Starting from the directory listing, the crop files of this class are
    removed and the outlier names of this class are added. The resulting
    list must contain no duplicates; its length is printed.

    Args:
        class_name: Class sub-directory / file-name prefix to check.
        outliers: File names that were excluded when linking the raw data.

    Raises:
        ValueError: If a crop file of this class is not in the train directory.
        AssertionError: If the reconstructed list contains a duplicate name.
    """
    train_list = os.listdir(os.path.join('train', class_name))

    # crop_data is optional (ref_train_data also tolerates its absence), so
    # a missing directory simply means there is nothing to subtract.
    crop_list = os.listdir('crop_data') if os.path.exists('crop_data') else []
    for file in crop_list:
        if file.startswith(class_name):
            train_list.remove(file)

    train_list.extend(file for file in outliers if file.startswith(class_name))

    assert len(train_list) == len(set(train_list))
    print(len(train_list))
    

In [8]:
# Check both classes; each call prints the reconstructed file count.
for class_name in ('dog', 'cat'):
    verify_data(class_name, outliers)

12500
12500


In [9]:
def verify_data_split(class_name: str):
    """Check that the three split folders exactly cover ``train/<class_name>``.

    The listings of ``pretrain``, ``finetune`` and ``val`` for the class are
    concatenated and their total is printed. The combined list must contain
    no repeated name and must be as long as the ``train`` listing.

    Raises:
        AssertionError: If a name occurs in more than one split or the
            totals differ.
    """
    combined = []
    for split_name in ('pretrain', 'finetune', 'val'):
        combined.extend(os.listdir(os.sep.join((split_name, class_name))))

    total = len(combined)
    print(total)
    assert len(set(combined)) == total

    train_entries = os.listdir(os.sep.join(('train', class_name)))
    assert total == len(train_entries)
    

In [10]:
# Confirm the split folders cover the train folder for both classes.
for class_name in ('dog', 'cat'):
    verify_data_split(class_name)

12465
12430
