In [1]:
import os
import struct
import tarfile
import urllib.request
import zlib
import _pickle as cPickle

import numpy as np
import scipy.misc

In [2]:
cifar_link = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
data_dir = 'temp'

# Working directory for the archive, the extracted batches and the sorted images.
# exist_ok=True makes re-runs a no-op (a *file* named data_dir still raises).
os.makedirs(data_dir, exist_ok=True)

# CIFAR-10 class names; a batch's integer label is used as an index into this list.
objects = ('airplane automobile bird cat deer '
           'dog frog horse ship truck').split()

In [3]:
# Download the CIFAR-10 tarball unless it is already on disk.
target_file = os.path.join(data_dir, 'cifar-10-python.tar.gz')
print(target_file)
if not os.path.isfile(target_file):
    print('CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)')
    print('This may take a few minutes, please wait.')
    # Download under a temporary name and rename only once it completed, so an
    # interrupted download never leaves a truncated archive that later runs trust.
    partial_file = target_file + '.part'
    urllib.request.urlretrieve(cifar_link, partial_file)
    os.replace(partial_file, target_file)

# Extract the archive into data_dir; skip this slow step when the batches
# directory from a previous run is already present.
batches_dir = os.path.join(data_dir, 'cifar-10-batches-py')
if not os.path.isdir(batches_dir):
    with tarfile.open(target_file) as tar:
        # The archive comes from the network: where supported, use the 'data'
        # extraction filter (rejects absolute paths, '..' members, devices, ...).
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=data_dir, filter='data')
        else:
            tar.extractall(path=data_dir)

temp\cifar-10-python.tar.gz
CIFAR-10 file not found. Downloading CIFAR data (Size = 163MB)
This may take a few minutes, please wait.


In [6]:
# Create one sub-folder per class for the training and the test (validation) images.
train_folder = 'train_dir'
test_folder = 'validation_dir'
for split_folder in (train_folder, test_folder):
    for class_name in objects:
        # exist_ok=True also fills in class folders missing from a partially
        # built tree, instead of skipping the whole split because its root exists.
        os.makedirs(os.path.join(data_dir, split_folder, class_name), exist_ok=True)

In [7]:
def load_batch_from_file(file):
    """Unpickle one CIFAR-10 batch file and return the resulting object.

    Parameters
    ----------
    file : str or path-like
        Path to a batch file such as ``data_batch_1`` or ``test_batch``.

    Returns
    -------
    object
        The unpickled object. The rest of this notebook reads its ``'labels'``,
        ``'filenames'`` and ``'data'`` entries.

    Notes
    -----
    ``encoding='latin1'`` is how Python 3 decodes 8-bit ``str`` objects stored
    in Python 2 pickles (presumably how these batches were written).
    Unpickling can execute arbitrary code: only load files from a trusted source.
    """
    # `with` closes the handle even if unpickling raises.
    with open(file, 'rb') as file_conn:
        return cPickle.load(file_conn, encoding='latin1')

In [11]:
def _write_png(path, rgb):
    """Write an (H, W, 3) uint8 array to `path` as an 8-bit truecolour PNG.

    Stdlib-only (zlib + struct) replacement for ``scipy.misc.imsave``, which has
    been removed from SciPy. Scanlines use PNG filter type 0 (no filtering).
    """
    pixels = np.ascontiguousarray(rgb, dtype=np.uint8)
    height, width = pixels.shape[:2]
    # Every scanline is prefixed with its filter-type byte (0 = None).
    scanlines = b''.join(b'\x00' + pixels[row].tobytes() for row in range(height))

    def chunk(tag, payload):
        # PNG chunk: big-endian length, type, data, CRC32 over (type + data).
        return (struct.pack('>I', len(payload)) + tag + payload
                + struct.pack('>I', zlib.crc32(tag + payload) & 0xFFFFFFFF))

    # IHDR: width, height, bit depth 8, colour type 2 (RGB),
    # compression / filter / interlace methods all 0.
    header = struct.pack('>IIBBBBB', width, height, 8, 2, 0, 0, 0)
    with open(path, 'wb') as out:
        out.write(b'\x89PNG\r\n\x1a\n')
        out.write(chunk(b'IHDR', header))
        out.write(chunk(b'IDAT', zlib.compress(scanlines)))
        out.write(chunk(b'IEND', b''))


def save_images_from_dict(image_dict, folder='data_dir', root=None, class_names=None):
    """Save every image of an unpickled CIFAR-10 batch as a PNG, sorted by class.

    Image ``i`` is written to
    ``<root>/<folder>/<class_names[labels[i]]>/<filenames[i]>``; missing class
    folders are created.

    Parameters
    ----------
    image_dict : dict
        Batch as returned by ``load_batch_from_file``; its ``'labels'``,
        ``'filenames'`` and ``'data'`` entries are read. Each ``data`` row is
        assumed to hold 3072 uint8 values laid out channel-first (a 32x32 R
        plane, then G, then B, each row-major), the documented CIFAR-10 layout.
    folder : str
        Sub-directory of `root` to save into (e.g. ``train_folder``).
    root : str, optional
        Base directory; defaults to the notebook-level ``data_dir``.
    class_names : sequence of str, optional
        Maps a label index to its folder name; defaults to the notebook-level
        ``objects``.
    """
    if root is None:
        root = data_dir
    if class_names is None:
        class_names = objects

    created = set()
    for ix, label in enumerate(image_dict['labels']):
        folder_path = os.path.join(root, folder, class_names[label])
        if folder_path not in created:
            os.makedirs(folder_path, exist_ok=True)
            created.add(folder_path)
        filename = image_dict['filenames'][ix]
        # (C, H, W) -> (H, W, C). A bare `.transpose()` would reverse *all* axes
        # into (W, H, C) and save each image mirrored along its diagonal.
        # `reshape` returns a new view instead of resizing the batch row in place.
        image = np.asarray(image_dict['data'][ix]).reshape(3, 32, 32).transpose(1, 2, 0)
        _write_png(os.path.join(folder_path, filename), image)

In [12]:
# Batch files inside the extracted archive.
data_location = os.path.join(data_dir, 'cifar-10-batches-py')
train_names = ['data_batch_' + str(x) for x in range(1, 6)]
test_names = ['test_batch']

# Sort the training batches into train_folder and the test batch into test_folder.
for batch_names, destination in ((train_names, train_folder), (test_names, test_folder)):
    for batch_name in batch_names:
        print('Saving images from file: {}'.format(batch_name))
        image_dict = load_batch_from_file(os.path.join(data_location, batch_name))
        save_images_from_dict(image_dict, folder=destination)

Saving images from file: data_batch_1
Saving images from file: data_batch_2
Saving images from file: data_batch_3
Saving images from file: data_batch_4
Saving images from file: data_batch_5
Saving images from file: test_batch


In [13]:
# Write the class names to a text file, one per line, in label-index order.
cifar_labels_file = os.path.join(data_dir, 'cifar10_labels.txt')
print('Writing labels file, {}'.format(cifar_labels_file))
with open(cifar_labels_file, 'w') as labels_file:
    labels_file.writelines('{}\n'.format(name) for name in objects)

Writing labels file, temp\cifar10_labels.txt
