In [1]:
import tensorflow as tf
tf.enable_eager_execution()

import math
import multiprocessing
import glob

# Exploring Meta-dataset's TFRecords

In [2]:
RECORDS_DIR = "/Users/gomerudo/workspace/thesis_results/dtd"
pattern = "{dir}/*.tfrecords".format(dir=RECORDS_DIR)
tfrecords_list = sorted(glob.glob(pattern))
tfrecord_data = tf.data.TFRecordDataset(tfrecords_list)

In [48]:
features = {
    'image': tf.FixedLenFeature([], dtype=tf.string),
    'label': tf.FixedLenFeature([], tf.int64)
}

c = 0
for record in tfrecord_data:
    if c > 5:
        break
    parsed = tf.parse_single_example(record, features)

    # The image (features)
    image_decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    img = Image.fromarray(image_decoded.numpy(), 'RGB')
    img.show(title="algo")
    c+=1


In [45]:
from PIL import Image

def _n_elements(tfrecords_list):
    c = 0
    for fn in tfrecords_list:
        for _ in tf.python_io.tf_record_iterator(fn):
            c += 1
    return c

def _parser(record, image_size):
    # the 'features' here include your normal data feats along
    # with the label for that data
    features = {
        'image': tf.FixedLenFeature([], dtype=tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }

    parsed = tf.parse_single_example(record, features)

    # The image (features)
    image_decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    image_resized = tf.image.resize_images(
        image_decoded,
        [image_size, image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True
    )
    image_normalized = image_resized

    # The label
    label = tf.cast(parsed['label'], tf.int32)

    return {'x': image_normalized}, label


def _input_fn(tfrecord_data, length, batch_size=128, img_size=84):
    dataset = tfrecord_data 

    dataset = dataset.map(lambda record: _parser(record, img_size))
    dataset = dataset.batch(batch_size)
    # iterator = dataset.make_one_shot_iterator()

    return dataset


# 2. Get the dataset as tf.Dataset object
img_size = 84
dataset = _input_fn(
    tfrecord_data, _n_elements(tfrecords_list), img_size=img_size
)

# Iterate over all images. For each batch, we append to the csv file to
# avoid memory issues if the dataset is too big.
c = 0
for idx, batch in enumerate(dataset):
    print("Processing batch #{i}".format(i=idx+1))
    for img, label in zip(batch[0]['x'], batch[1]):
        if c > 4:
            break
        img_np = img.numpy()
        img = Image.fromarray(img_np, 'RGB')
        img.show()
        c += 1


Processing batch #1
Processing batch #2
Processing batch #3
Processing batch #4
Processing batch #5
Processing batch #6
Processing batch #7
Processing batch #8
Processing batch #9
Processing batch #10
Processing batch #11
Processing batch #12
Processing batch #13
Processing batch #14
Processing batch #15
Processing batch #16
Processing batch #17
Processing batch #18
Processing batch #19
Processing batch #20
Processing batch #21
Processing batch #22
Processing batch #23
Processing batch #24
Processing batch #25
Processing batch #26
Processing batch #27
Processing batch #28
Processing batch #29
Processing batch #30
Processing batch #31
Processing batch #32
Processing batch #33
Processing batch #34
Processing batch #35
Processing batch #36
Processing batch #37
Processing batch #38
Processing batch #39
Processing batch #40
Processing batch #41
Processing batch #42
Processing batch #43
Processing batch #44
Processing batch #45


## Playing with TensorFlow's TFRecords, Dataset and Estimator

In [102]:
image_size = 84

def n_elements(records_list):
    """Return the number of elements in a tensorflow records file."""
    count = 0
    for tfrecords_file in records_list:
        for _ in tf.python_io.tf_record_iterator(tfrecords_file):
            count += 1
    return count


def parser(record_dataset):
    """Parse a given TFRecordsDataset object."""
    # This is the definition we expect in the TFRecords for meta-dataset
    features = {
        'image': tf.FixedLenFeature([], dtype=tf.string),
        'label': tf.FixedLenFeature([], tf.int64)
    }
    exp_image_size = 84

    # 1. We parse the record_dataset with the features defined above.
    parsed = tf.parse_single_example(record_dataset, features)

    # 2. We will decode the image as a jpeg with 3 channels and resize it to
    #    the expected image size
    image_decoded = tf.image.decode_jpeg(parsed['image'], channels=3)
    image_resized = tf.image.resize_images(
        image_decoded,
        [exp_image_size, exp_image_size],
        method=tf.image.ResizeMethod.BILINEAR,
        align_corners=True
    )
    # 3. And we normalize the dataset in the range [0, 1]
    image_normalized = image_resized / 255.0

    # 4. we make the label an int32.
    label = tf.cast(parsed['label'], tf.int32)

    # 5. We return as dataset a s pair ( {features}, label)
    return {'x': image_normalized}, label


def metadataset_input_fn(tfrecord_data, data_length, batch_size=128,
                         is_train=True, split_prop=0.33, random_seed=32,
                         is_distributed=False):
    """Input function for a tensorflow estimator."""
    trainset_length = math.floor(data_length*(1. - split_prop))

    files = tf.data.Dataset.list_files(
        tfrecord_data, shuffle=False
    )
    n_threads = multiprocessing.cpu_count()
    print(
        "Number of threads available for dataset processing is %d", n_threads
    )
    dataset = files.apply(
        tf.contrib.data.parallel_interleave(
            lambda filename: tf.data.TFRecordDataset(filename),
            cycle_length=n_threads
        )
    )
    dataset = dataset.shuffle(data_length, seed=random_seed, reshuffle_each_iteration=False)

    if is_train:
        dataset = dataset.take(trainset_length)
        current_length = trainset_length
        dataset = dataset.apply(
            tf.contrib.data.shuffle_and_repeat(current_length, 10)
        )
    else:
        dataset = dataset.skip(trainset_length)
        current_length = data_length - trainset_length

    # shuffle and repeat examples for better randomness and allow training
    # beyond one epoch
#     count_repeat = 30 if is_train else 1

    print("Current length in input_fn %d", current_length)

    # map the parse function to each example individually in threads*2
    # parallel calls
    dataset = dataset.map(
        map_func=lambda example: parser(example),
        num_parallel_calls=n_threads
    )

    # batch the examples if using training, otherwise we want to evaluate on
    # the whole dataset in one single step
    dataset = dataset.batch(batch_size=batch_size)

    # prefetch batch
    dataset = dataset.prefetch(buffer_size=32)

    return dataset
#     if is_distributed:
#         return dataset

#     iterator = dataset.make_one_shot_iterator()
#     return iterator.get_next()


### Train set

In [104]:
import numpy as np

dataset_length = n_elements(tfrecords_list)
print("Dataset length is", dataset_length)

dataset = metadataset_input_fn(
    tfrecord_data=pattern,
    data_length=dataset_length,
    batch_size=128,
    is_train=True,
    split_prop=0.33,
    random_seed=32,
    is_distributed=False
)

exp_batches = math.ceil(dataset_length*(1-0.33)/128)
print("Expected number of batches per epoch:", exp_batches)

counts_list = []
classes_count = None
obs_set = set()
for i, batch in enumerate(dataset):
    if i % exp_batches == 0:
        print("Batch number", i+1)
        print("Starting over the dataset.")
        if classes_count is not None:
            counts_list.append(classes_count)
        classes_count = np.zeros(47, dtype=int)

    for img, label in zip(batch[0]['x'], batch[1]):
        img_np = img.numpy()
        obs_set.add(str(img_np.flatten()))
        label_np = int(label.numpy())
        classes_count[label_np] += 1

counts_list.append(classes_count)
for i, c_list in enumerate(counts_list):
    print("List number", i+1)
    print("Sum is", sum(c_list))
    print(c_list)

print("N different observations seen:", len(obs_set))

Dataset length is 5640
Number of threads available for dataset processing is %d 4
Current length in input_fn %d 3778
Expected number of batches per epoch: 30
Batch number 1
Starting over the dataset.
Batch number 31
Starting over the dataset.
Batch number 61
Starting over the dataset.
Batch number 91
Starting over the dataset.
Batch number 121
Starting over the dataset.
Batch number 151
Starting over the dataset.
Batch number 181
Starting over the dataset.
Batch number 211
Starting over the dataset.
Batch number 241
Starting over the dataset.
Batch number 271
Starting over the dataset.
List number 1
Sum is 3840
[74 89 83 81 89 86 83 86 80 84 73 82 80 87 80 70 85 72 76 85 82 90 90 80
 81 82 83 82 74 78 81 83 87 84 84 85 78 79 76 77 87 88 81 87 76 82 78]
List number 2
Sum is 3840
[70 93 80 84 91 88 82 90 82 82 74 79 85 87 82 71 84 71 75 83 81 87 88 82
 84 85 84 85 73 81 78 81 85 77 83 86 79 80 78 74 89 84 76 86 80 85 76]
List number 3
Sum is 3840
[76 85 85 76 94 86 79 83 77 82 76 81 77 8

The repeat makes the batches of equal size, meaning that if the last batch is not complete it will fill the missing values to obtain a complete batch (e.g. complete batch of 128 elements).

### Test set

In [105]:
import numpy as np

dataset_length = n_elements(tfrecords_list)
print("Dataset length is", dataset_length)

dataset = metadataset_input_fn(
    tfrecord_data=pattern,
    data_length=dataset_length,
    batch_size=128,
    is_train=False,
    split_prop=0.33,
    random_seed=32,
    is_distributed=False
)

exp_batches = math.ceil(dataset_length*(1-0.33)/128)
print("Expected number of batches per epoch:", exp_batches)

counts_list = []
classes_count = None
obs_set_test = set()
for i, batch in enumerate(dataset):
    if i % exp_batches == 0:
        print("Batch number", i+1)
        print("Starting over the dataset.")
        if classes_count is not None:
            counts_list.append(classes_count)
        classes_count = np.zeros(47, dtype=int)

    for img, label in zip(batch[0]['x'], batch[1]):
        img_np = img.numpy()
        obs_set_test.add(str(img_np.flatten()))
        label_np = int(label.numpy())
        classes_count[label_np] += 1

counts_list.append(classes_count)
for i, c_list in enumerate(counts_list):
    print("List number", i+1)
    print("Sum is", sum(c_list))
    print(c_list)

print("N different observations seen:", len(obs_set_test))

Dataset length is 5640
Number of threads available for dataset processing is %d 4
Current length in input_fn %d 1862
Expected number of batches per epoch: 30
Batch number 1
Starting over the dataset.
List number 1
Sum is 1862
[48 32 39 40 31 35 39 35 41 38 47 41 40 35 42 50 37 50 45 36 39 32 32 41
 39 38 37 39 47 42 41 39 34 40 37 36 44 42 44 45 34 34 43 34 45 39 44]
N different observations seen: 1828


In [106]:
print("Diff 1", len(obs_set.difference(obs_set_test)))
print("Diff 2", len(obs_set_test.difference(obs_set)))

Diff 1 3676
Diff 2 1811
