# Fine-grained TFRecords

* tfrecords 를 처음에 하나로 만들었는데 그러면 안 됨.

## Experiment

* 클래스별로 하나씩 만들어보고, `batch_shuffle` 과 `batch_shuffle_join` 의 차이점을 확인해보자.
* 클래스별로 하나씩 만들어도 충분한 지 확인하자.

## Retry!

* http://coolingoff.tistory.com/23
* 이걸 참고해서 해 보자.
* 그냥 reading data 번역인 거 같긴 한데... 그것도 참고하고..

# Concolusion

Setting: 각 클래스별 tfrecords 파일 하나

* train data
    * 트레이닝 데이터는 랜덤하게 셔플해서 받을 수 있어야 함
    * 그러나 shuffle_batch 는 tfrecords 파일 하나에서 받아오므로 제대로 셔플이 안 됨 (한 클래스에 대해서 읽어온 후 셔플)
    * 따라서 shuffle_batch_join 을 써야 하고 batch_join 을 써도 다양한 클래스에서 읽어오기는 함
        * 이렇게 되면 read_thread 수를 클래스 수 이상으로 해 줘야 할 듯
        * ImageNet 같이 데이터/클래스 전부 엄청 많으면 어떻게 해야 하지?
* test data
    * 테스트 데이터는 5개의 tfrecords 를 통째로 읽어와야 함
    * 생각해보면 num_epoch 을 사용하면 컨트롤 할 수 있을 것 같은데?
    * => 실험결과 된다!
    
## 결론 of 결론
    
**`shuffle_batch_join` + num_epochs 써라!**

## Prepare data

In [1]:
%matplotlib inline
import tensorflow as tf
import scipy
import matplotlib.pyplot as plt
slim = tf.contrib.slim

  from ._conv import register_converters as _register_converters


In [2]:
import os, sys, glob, shutil
import urllib
import tarfile
import numpy as np
from scipy.io import loadmat
import time

In [3]:
def download_file(url, dest=None):
    if not dest:
        dest = 'data/' + url.split('/')[-1]
    if sys.version_info[0] == 3:
        urllib.request.urlretrieve(url, dest)
    else:
        urllib.urlretrieve(url, dest)

### Download TF Flower dataset

In [4]:
LABELS = ["daisy", "dandelion", "roses", "sunflowers", "tulips"]
url = "http://download.tensorflow.org/example_images/flower_photos.tgz"

In [5]:
if not os.path.exists("data/flower_photos"):
    if not os.path.exists("data"):
        os.mkdir("data")
    print("Download flower dataset..")
    download_file(url)
    print("Extracting dataset..")
    tarfile.open("data/flower_photos.tgz", "r:gz").extractall(path="data/")
#     os.remove("data/flower_photos.tgz") # 굳이...

### Split dataset into train/test

In [6]:
train_ratio = 0.9
remake = False
parent_dir = "data/flower_photos"
train_dir = os.path.join(parent_dir, "train")
test_dir = os.path.join(parent_dir, "test")

if not os.path.exists(train_dir) or not os.path.exists(test_dir) or remake:
    # make dirs
    for label in LABELS:
        # tf.gfile.MakeDirs make dir recursively & ignore exist dir
        tf.gfile.MakeDirs(os.path.join(train_dir, label))
        tf.gfile.MakeDirs(os.path.join(test_dir, label))

    # copy files
    for i, label in enumerate(LABELS):
        dir_name = os.path.join(parent_dir, label)
        paths = glob.glob(dir_name + "/*.jpg")
        num_examples = len(paths)
        for j, path in enumerate(paths):
            fn = os.path.basename(path)
            is_train = j < (num_examples * train_ratio)

            if is_train:
                to_path = os.path.join(train_dir, label, fn)
            else:
                to_path = os.path.join(test_dir, label, fn)
            
            tf.gfile.Copy(path, to_path)

In [7]:
!find ./data/flower_photos/test ./data/flower_photos/train -type f | cut -d/ -f4 | uniq -c

    364 test
   3306 train


### Convert to `TFRecords` format

In [8]:
def _bytes_features(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))


def _int64_features(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))

In [9]:
def dir_to_tfrecords(dir_path, class_idx, tfrecords_path):
    '''convert image-containing dir to tfrecords without exist check.
    return: # of image files
    '''
    with tf.python_io.TFRecordWriter(tfrecords_path) as writer:
        paths = glob.glob(dir_path + "/*.jpg")
        num_examples = len(paths)
        for path in paths:
            im = scipy.misc.imread(path)
            im = scipy.misc.imresize(im, [64, 64])
            im_raw = im.tostring()
            features = {
                "shape": _int64_features(im.shape),
                "image": _bytes_features([im_raw]),
                "label": _int64_features([class_idx])
            }

            example = tf.train.Example(features=tf.train.Features(feature=features))
            writer.write(example.SerializeToString())
            
        return num_examples

In [10]:
print("Convert dataset to fine-grained TFRecord files ..")

tfrecords_format = "data/flower_photos_{}_{}.tfrecords"

num_train = 0
num_test = 0

remake_tfrecords = False
tfrecords_path_list = [tfrecords_format.format(top, label) for top in ['train' , 'test'] for label in LABELS]

if all(map(tf.gfile.Exists, tfrecords_path_list)) and remake_tfrecords == False:
    # already exists
    num_train = 3306
    num_test = 364
else:
    # make tfrecords files
#     num_train = dir_to_tfrecords('data/flower_photos/train/', tfrecords_train_fn)
#     num_test = dir_to_tfrecords('data/flower_photos/test/', tfrecords_test_fn)
    num = {'train': 0, 'test': 0}
    for top in ['train', 'test']:
        for i, label in enumerate(LABELS):
            dir_path = os.path.join('data/flower_photos/', top, label)
            tfrecords_path = tfrecords_format.format(top, label)
            num_cur = dir_to_tfrecords(dir_path, i, tfrecords_path)
            num[top] += num_cur
            print('# of {}/{}: {}'.format(top, label, num_cur))
    num_train = num['train']
    num_test = num['test']

    
# how to get num_examples from tfrecords file?
print(num_train, num_test)

Convert dataset to fine-grained TFRecord files ..
3306 364


## Read data from TFRecords

In [11]:
summary_root_dir = './summary/fine-grained/'
summary_train_dir = os.path.join(summary_root_dir, 'train')
summary_test_dir = os.path.join(summary_root_dir, 'test')
model_name = 'tfrecords-fine-grained'
tfrecords_path_list_train = [path for path in tfrecords_path_list if 'train' in path]
tfrecords_path_list_test = [path for path in tfrecords_path_list if 'test' in path]

In [12]:
tfrecords_path_list_train

['data/flower_photos_train_daisy.tfrecords',
 'data/flower_photos_train_dandelion.tfrecords',
 'data/flower_photos_train_roses.tfrecords',
 'data/flower_photos_train_sunflowers.tfrecords',
 'data/flower_photos_train_tulips.tfrecords']

In [13]:
tfrecords_path_list_test

['data/flower_photos_test_daisy.tfrecords',
 'data/flower_photos_test_dandelion.tfrecords',
 'data/flower_photos_test_roses.tfrecords',
 'data/flower_photos_test_sunflowers.tfrecords',
 'data/flower_photos_test_tulips.tfrecords']

# Check our batch!

In [14]:
def read_data(filename_queue):
    with tf.variable_scope('read_data'):
        reader = tf.TFRecordReader()
        key, records = reader.read(filename_queue)
        
        # parse records
        features = tf.parse_single_example(
            records,
            features={
                "shape": tf.FixedLenFeature([3], tf.int64),
                "image": tf.FixedLenFeature([], tf.string),
                "label": tf.FixedLenFeature([], tf.int64)
            }
        )

        image = tf.decode_raw(features["image"], tf.uint8)
        shape = tf.cast(features["shape"], tf.int32)
        label = tf.cast(features["label"], tf.int32)

        # preproc
        image = tf.reshape(image, [64, 64, 3])
#         image = tf.image.resize_images(images=image, size=[64, 64])
        image = tf.cast(image, tf.float32)
        image = image / 255.0

        one_hot_label = tf.one_hot(label, depth=5)
        
        return image, one_hot_label

In [15]:
# https://www.tensorflow.org/programmers_guide/reading_data

def get_batch_join(tfrecords_path_list, batch_size, shuffle=False, 
                   read_thread=5, min_after_dequeue=500, num_epochs=None):
    with tf.variable_scope("get_batch_join"):
        # make input pipeline
        filename_queue = tf.train.string_input_producer(tfrecords_path_list, shuffle=shuffle, num_epochs=num_epochs)
        # 5개의 reader 를 두고 각각 읽어오게 하자
        example_list = [read_data(filename_queue) for _ in range(read_thread)]
        
        # train case (shuffle)
        capacity = min_after_dequeue + 3*batch_size
        if shuffle:
            images, labels = tf.train.shuffle_batch_join(tensors_list=example_list, batch_size=batch_size,
                                                         capacity=capacity, min_after_dequeue=min_after_dequeue,
                                                         allow_smaller_final_batch=True)
        else:
            images, labels = tf.train.batch_join(example_list, batch_size, capacity=capacity, 
                                                 allow_smaller_final_batch=True)
            
        return images, labels

In [16]:
def get_batch(tfrecords_path_list, batch_size, shuffle=False, 
              read_thread=5, min_after_dequeue=500, num_epochs=None):
    with tf.variable_scope("get_batch"):
        filename_queue = tf.train.string_input_producer(tfrecords_path_list, shuffle=shuffle, num_epochs=num_epochs)
        image, label = read_data(filename_queue)
        
        capacity = min_after_dequeue + 3*batch_size
        if shuffle:
            images, labels = tf.train.shuffle_batch([image, label], batch_size=batch_size, capacity=capacity, 
                                                    min_after_dequeue=min_after_dequeue, num_threads=read_thread,
                                                    allow_smaller_final_batch=True)
        else:
            images, labels = tf.train.batch([image, label], batch_size, capacity=capacity, num_threads=read_thread,
                                            allow_smaller_final_batch=True)
        
        return images, labels

In [17]:
num_train, num_test

(3306, 364)

In [18]:
tf.reset_default_graph()
# default min_after_dequeue = 500
X, y = get_batch(tfrecords_path_list_train, batch_size=128, shuffle=True, num_epochs=None)

# sess = tf.Session()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    cur_X, cur_y = sess.run([X, y])

    coord.request_stop()
    coord.join(threads)

In [19]:
np.sum(cur_y, axis=0)

array([  2.,   0., 126.,   0.,   0.], dtype=float32)

## Test for all cases

`batch`, `batch_join`, `shuffle_batch`, `shuffle_batch_join` 4개의 케이스에 대해 테스트해본다.

In [20]:
print(num_train, num_test)
for data_type in ['train', 'test']:
    n_examples = num_train if data_type == 'train' else num_test
    kargs = {
        'min_after_dequeue': 500,
        'num_epochs': 5,
        'read_thread': 20,
        'batch_size': 512
    }
    if data_type == 'train':
        kargs['tfrecords_path_list'] = tfrecords_path_list_train
    else:
        kargs['tfrecords_path_list'] = tfrecords_path_list_test
    
    print("===== {} =====".format(data_type))
    for func_type in ['batch', 'batch_join', 'shuffle_batch', 'shuffle_batch_join']:
        tf.reset_default_graph()
        kargs['shuffle'] = 'shuffle' in func_type

        if 'join' in func_type:
            X, y = get_batch_join(**kargs)
        else:
            X, y = get_batch(**kargs)

        print("[{}]".format(func_type))

        # sess = tf.Session()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            # epochs 를 쓰면 local_variables_init 도 해줘야 함.
            # 위에서 지정한 num_epochs 가 local_variable 로 그래프에 박히는 듯
            sess.run(tf.local_variables_initializer())

            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            
            st = time.time()

            # num_epochs 를 지정해주면 이러한 방식으로 해야 함
            n_iter = None
            step_cnt = 0
            data_cnt = np.zeros([5])
            try:
                while not coord.should_stop():
                    cur_X, cur_y = sess.run([X, y])
                    print(np.sum(cur_y, axis=0), end=" ")
                    data_cnt += np.sum(cur_y, axis=0)
                    step_cnt += 1
                    epoch = np.ceil(float(step_cnt * kargs['batch_size']) / n_examples)
                    print(int(epoch))
            except tf.errors.OutOfRangeError:
                print('Done -- epoch limit reached')
                print(data_cnt, np.sum(data_cnt, dtype=np.int32))
                print('elapsed time: {:.2f}s'.format(time.time() - st))
            finally:
                coord.request_stop()

            print("")
#             coord.request_stop()
            coord.join(threads)

3306 364
===== train =====
[batch]
[512.   0.   0.   0.   0.] 1
[ 58. 454.   0.   0.   0.] 1
[  0. 355. 157.   0.   0.] 1
[  0.   0. 420.  92.   0.] 1
[  0.   0.   0. 512.   0.] 1
[  0.   0.   0.  26. 486.] 1
[278.   0.   0.   0. 234.] 2
[292. 220.   0.   0.   0.] 2
[  0. 512.   0.   0.   0.] 2
[  0.  77. 435.   0.   0.] 2
[  0.   0. 142. 370.   0.] 2
[  0.   0.   0. 260. 252.] 2
[ 44.   0.   0.   0. 468.] 3
[512.   0.   0.   0.   0.] 3
[ 14. 498.   0.   0.   0.] 3
[  0. 311. 201.   0.   0.] 3
[  0.   0. 376. 136.   0.] 3
[  0.   0.   0. 494.  18.] 3
[  0.   0.   0.   0. 512.] 3
[322.   0.   0.   0. 190.] 4
[248. 264.   0.   0.   0.] 4
[  0. 512.   0.   0.   0.] 4
[  0.  33. 479.   0.   0.] 4
[  0.   0.  98. 414.   0.] 4
[  0.   0.   0. 216. 296.] 4
[ 88.   0.   0.   0. 424.] 5
[482.  30.   0.   0.   0.] 5
[  0. 512.   0.   0.   0.] 5
[  0. 267. 245.   0.   0.] 5
[  0.   0. 332. 180.   0.] 5
[  0.   0.   0. 450.  62.] 5
[  0.   0.   0.   0. 512.] 5
[  0.   0.   0.   0. 146.] 6
Done -- 