# 数据整理

In [2]:
import os
import shutil
import random
import numpy as np
import json
import matplotlib.pyplot as plt

# Seed Python's global `random` module; random_sample_data below shuffles
# with it, so the seed pins the train/val split for a given call sequence.
random.seed(2018)
# IPython magic (not plain Python): render matplotlib figures inline.
%matplotlib inline

### 将数据链接到'train/cat'、'train/dog'、'test'文件夹（剔除total_outliers.json中列出的异常图片，并加入crop_data中的图片），用于生成特征向量：

In [28]:
# Outlier training images to leave out of train/cat and train/dog.
# Presumably a JSON list of file names -- it is only used for
# `file_name in outliers` checks against os.listdir() entries below.
with open('total_outliers.json', 'r') as outlier_fp:
    outliers = json.load(outlier_fp)
print(len(outliers))

83


In [5]:
src_train_dir = '../dog_vs_cat_data/train'
src_test_dir = '../dog_vs_cat_data/test1'
src_crop_dir = 'crop_data'
des_train_dir = 'train'
des_test_dir = 'test'

cur_path = os.getcwd()


def _link_images_by_class(src_dir, cat_dir, dog_dir, exclude=()):
    """Symlink every image in *src_dir* into *cat_dir* or *dog_dir* by name prefix.

    Files whose names start with 'cat' are linked into *cat_dir*, those
    starting with 'dog' into *dog_dir*. Names contained in *exclude* are
    skipped, and so are names with neither prefix (e.g. stray '.DS_Store'
    files), which would otherwise be silently filed as dog images.
    Link targets are built from the current working directory, so they are
    absolute paths.
    """
    abs_src_dir = os.path.join(cur_path, src_dir)
    for name in os.listdir(src_dir):
        if name in exclude:
            continue
        if name.startswith('cat'):
            dst_dir = cat_dir
        elif name.startswith('dog'):
            dst_dir = dog_dir
        else:
            print('skipping unlabeled file:', os.path.join(src_dir, name))
            continue
        os.symlink(os.path.join(abs_src_dir, name), os.path.join(dst_dir, name))


# ref train data: rebuild train/cat and train/dog from scratch.
if os.path.exists(des_train_dir):
    shutil.rmtree(des_train_dir)

des_train_cat_dir = des_train_dir + '/cat'
des_train_dog_dir = des_train_dir + '/dog'
os.makedirs(des_train_cat_dir)
os.makedirs(des_train_dog_dir)

# Original training images, minus the outliers loaded earlier.
_link_images_by_class(src_train_dir, des_train_cat_dir, des_train_dog_dir,
                      exclude=outliers)

# ref crop_data: every file here is added as well (no outlier filtering).
_link_images_by_class(src_crop_dir, des_train_cat_dir, des_train_dog_dir)

# ref test data: the whole test directory is linked as one sub-folder
# (test/test) rather than file by file -- presumably so a
# directory-per-class image loader can read it; confirm against the consumer.
if os.path.exists(des_test_dir):
    shutil.rmtree(des_test_dir)

os.mkdir(des_test_dir)
os.symlink(os.path.join(cur_path, src_test_dir), os.path.join(des_test_dir, 'test'),
           target_is_directory=True)

### 分割数据至"train_split"、"val_split"文件夹，用于fine-tune：

In [14]:
def random_sample_data(total_num: int, val_split=0.2):
    """Randomly partition the indices ``0 .. total_num - 1`` into train and validation.

    Args:
        total_num: Number of items to split.
        val_split: Fraction of items assigned to validation, in [0, 1].
            The validation size is ``round(total_num * val_split)``
            (Python's ``round``, i.e. ties go to the even integer).

    Returns:
        ``(train_indices, val_indices)``: two disjoint lists that together
        hold every index in ``range(total_num)`` exactly once, in shuffled
        order. The shuffle uses the global ``random`` module state.

    Raises:
        ValueError: if *val_split* is outside [0, 1]. A negative value would
            otherwise make the slices below return a bogus split instead of
            failing.
    """
    if not 0 <= val_split <= 1:
        raise ValueError('val_split must be in [0, 1], got {!r}'.format(val_split))
    val_num = round(total_num * val_split)
    index_list = list(range(total_num))
    random.shuffle(index_list)
    return index_list[val_num:], index_list[:val_num]


In [27]:
src_train_dir = '../dog_vs_cat_data/train'
des_train_split_dir = 'train_split'
des_val_split_dir = 'val_split'
cur_path = os.getcwd()


def _link_split(class_name, file_list, val_index_list, train_dst_dir, val_dst_dir):
    """Symlink each image of one class into its train or validation folder.

    *file_list* holds the entries of train/<class_name>. Positions listed in
    *val_index_list* go to *val_dst_dir*; all others go to *train_dst_dir*.
    Each entry is read with os.readlink, so it must itself be a symlink; the
    new link points at that link's target, not at the link itself.
    """
    src_dir = os.path.join(cur_path, 'train', class_name)
    # A set gives O(1) membership; `i in <list>` for every file made the
    # loop quadratic in the number of images.
    val_index_set = set(val_index_list)
    for i, file_name in enumerate(file_list):
        dst_dir = val_dst_dir if i in val_index_set else train_dst_dir
        os.symlink(os.readlink(os.path.join(src_dir, file_name)),
                   os.path.join(dst_dir, file_name))


# Sort the listings so that, for a given random state, the split does not
# depend on the arbitrary order in which os.listdir returns entries.
cat_file_list = sorted(os.listdir('train/cat'))
dog_file_list = sorted(os.listdir('train/dog'))

# split (validation fraction = random_sample_data's default val_split)
cat_train_index_list, cat_val_index_list = random_sample_data(len(cat_file_list))
dog_train_index_list, dog_val_index_list = random_sample_data(len(dog_file_list))

# ref data: recreate empty train_split/{cat,dog} and val_split/{cat,dog}.
if os.path.exists(des_train_split_dir):
    shutil.rmtree(des_train_split_dir)
if os.path.exists(des_val_split_dir):
    shutil.rmtree(des_val_split_dir)

des_train_split_cat_dir = des_train_split_dir + '/cat'
des_train_split_dog_dir = des_train_split_dir + '/dog'
des_val_split_cat_dir = des_val_split_dir + '/cat'
des_val_split_dog_dir = des_val_split_dir + '/dog'
os.makedirs(des_train_split_cat_dir)
os.makedirs(des_train_split_dog_dir)
os.makedirs(des_val_split_cat_dir)
os.makedirs(des_val_split_dog_dir)

_link_split('cat', cat_file_list, cat_val_index_list,
            des_train_split_cat_dir, des_val_split_cat_dir)
_link_split('dog', dog_file_list, dog_val_index_list,
            des_train_split_dog_dir, des_val_split_dog_dir)

del cat_file_list
del dog_file_list
del cat_train_index_list
del cat_val_index_list
del dog_train_index_list
del dog_val_index_list