In [0]:
import os
import random
import shutil
import zipfile

import matplotlib.image as mpimg
import numpy as np
import tensorflow as tf
from google.colab import files
from matplotlib import pyplot as plt
from skimage import io
# Turn off matplotlib's default axes grid so imshow output has no grid lines on top.
plt.rcParams["axes.grid"] = False

# Feature spec passed to tf.parse_single_example when reading the Examples
# written by GenerateTFRecord: four int64 scalars plus the encoded image bytes.
features = dict(
    rows=tf.FixedLenFeature([], tf.int64),
    cols=tf.FixedLenFeature([], tf.int64),
    channels=tf.FixedLenFeature([], tf.int64),
    image=tf.FixedLenFeature([], tf.string),
    label=tf.FixedLenFeature([], tf.int64),
)

In [0]:
# Optional: skip this cell if the image folders are already on the machine.
# Otherwise upload zip archive(s), each containing a folder of images.
uploaded = files.upload()
for fn in uploaded.keys():
    # Extract in Python instead of `!unzip $fn`: the uploaded name is never
    # interpolated into a shell command, and re-runs overwrite existing files
    # instead of stopping at unzip's interactive "replace?" prompt.
    if not zipfile.is_zipfile(fn):
        print("Skipping non-zip upload:", fn)
        continue
    with zipfile.ZipFile(fn) as archive:
        archive.extractall()

# GENERATION AND EXTRACTION CLASSES

In [0]:
def get_label_with_filename(filename):
    """Return the integer label encoded in an image file name.

    The label is the second '_'-separated field of the base name, cut at its
    first '.', e.g. '/data/ColonDB_Train/img_12.png' -> 12.

    Args:
        filename: path (absolute or relative) or bare file name of the image.

    Returns:
        int: the parsed label.

    Raises:
        IndexError: if the base name contains no '_'.
        ValueError: if the extracted field is not an integer.
    """
    # os.path.basename uses the platform separator; splitting on '/' alone
    # fails on the backslash paths os.path.abspath produces on Windows.
    name = os.path.basename(filename)
    label_field = name.split('_')[1].split('.')[0]
    return int(label_field)

def triplet_preparation(img_paths, block_size=10):
    img_paths.sort(key=lambda x: get_label_with_filename(x))
    
    blocks = [img_paths[i:i+block_size] for i in range(0, len(img_paths), block_size)]
    random.shuffle(blocks)
    
    return [b for bs in blocks for b in bs]

In [0]:
# CLASS TO GENERATE A TFRECORD FROM ALL THE IMAGES IN A FOLDER
class GenerateTFRecord:
    """Serialize a folder of labelled images into a TFRecord file.

    Each image becomes one tf.train.Example with int64 features 'rows', 'cols',
    'channels' and 'label' (parsed from the file name by
    get_label_with_filename), plus 'image' holding the file's raw encoded bytes.
    """

    def convert_image_folder(self, img_folder, tfrecord_file_name, block_size=10):
        """Write every image file in ``img_folder`` to ``tfrecord_file_name``.

        Args:
            img_folder: directory whose (non-hidden) files are images named so
                that get_label_with_filename can parse them.
            tfrecord_file_name: path of the TFRecord file to write.
            block_size: if > 0, images are sorted by label and then shuffled in
                blocks of this size (see triplet_preparation); otherwise they
                are only sorted by label.
        """
        # Keep only regular, non-hidden files: sub-directories and entries such
        # as .DS_Store carry no label and would make get_label_with_filename fail.
        img_paths = [
            os.path.abspath(os.path.join(img_folder, name))
            for name in os.listdir(img_folder)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(img_folder, name))
        ]

        if block_size > 0:
            img_paths = triplet_preparation(img_paths, block_size)
        else:
            img_paths.sort(key=get_label_with_filename)
        # Summarize instead of dumping every path into the notebook output.
        print("Writing %d images to %s" % (len(img_paths), tfrecord_file_name))

        with tf.python_io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in img_paths:
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())

    @staticmethod
    def _int64_feature(value):
        """Wrap a single integer in a tf.train.Feature."""
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    def _convert_image(self, img_path):
        """Build the tf.train.Example for the image at ``img_path``.

        Returns:
            tf.train.Example with the image shape, channel count and label as
            int64 features and the unmodified file bytes under 'image'.
        """
        label = get_label_with_filename(img_path)
        # The image is decoded only to learn its shape; the stored payload is
        # the original encoded bytes read below.
        img_shape = mpimg.imread(img_path).shape

        with tf.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

        # A 2-D array is a single-channel image; otherwise the third axis
        # holds the channel count.
        channels = img_shape[2] if len(img_shape) == 3 else 1

        return tf.train.Example(features=tf.train.Features(feature={
            'rows': self._int64_feature(img_shape[0]),
            'cols': self._int64_feature(img_shape[1]),
            'channels': self._int64_feature(channels),
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
            'label': self._int64_feature(label),
        }))

In [0]:
# CLASS TO EXTRACT IMAGES FROM A TFRECORD AND RETURN THEM AS A DATASET
class TFRecordExtractor:
    """Read a TFRecord written by GenerateTFRecord back as a tf.data.Dataset."""

    def __init__(self, tfrecord_file):
        """Store the absolute path of ``tfrecord_file`` for later reading."""
        self.tfrecord_file = os.path.abspath(tfrecord_file)

    def _extract_fn(self, tfrecord):
        """Parse one serialized Example into an ``(image, label)`` pair.

        Returns:
            image: float32 tensor with pixel values scaled from [0, 255] to
                [0.0, 1.0].
            label: int64 scalar tensor.
        """
        sample = tf.parse_single_example(tfrecord, features)

        # NOTE: decode_image gives a tensor with no static shape (it may be
        # 3-D, or 4-D for GIFs); set a shape downstream if an op needs it.
        image = tf.image.decode_image(sample['image'], dtype=tf.uint8)
        image = tf.cast(image, tf.float32) / 255.0

        # 'label' is declared as tf.int64 in `features`, so no cast is needed.
        return image, sample['label']

    def extract_image(self):
        """Return a Dataset yielding ``(image, label)`` pairs from the record file."""
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        return dataset.map(self._extract_fn)

# TFRECORD GENERATION AND TESTING

In [0]:
%%time
# GENERATE THE TFRecord for train and test sets
t = GenerateTFRecord()

train_name = 'ColonDB_Train'
test_name = 'ColonDB_Test'

train_block_size = 0
test_block_size = 0

!rm $train_name'.tfrecord'
!rm $test_name'.tfrecord'

t.convert_image_folder(train_name, train_name+'.tfrecord', block_size=train_block_size)
t.convert_image_folder(test_name, test_name+'.tfrecord', block_size=test_block_size)

In [0]:
# TEST THE READING AND PARSING OF A TFRECORD INTO A DATASET
t = TFRecordExtractor(train_name + '.tfrecord')
dataset = t.extract_image()

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    img, label = sess.run(next_element)

# Summarize the array instead of printing every pixel value.
print(img.shape, label)
print("dtype:", img.dtype, "min:", img.min(), "max:", img.max())

# imshow accepts (H, W), (H, W, 3) or (H, W, 4); drop a trailing single
# channel so grayscale records can be displayed as well.
if img.ndim == 3 and img.shape[-1] == 1:
    img = img[..., 0]

fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(20, 5))
ax.imshow(img, cmap='gray' if img.ndim == 2 else None)
ax.set_title('Example image')
plt.show()

In [0]:
# Read the training TFRecord back and list every label in record order.
t = TFRecordExtractor(train_name + '.tfrecord')
dataset = t.extract_image()

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

labels = []
with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            _, label = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            # Raised by the iterator once every record has been read.
            break
        labels.append(int(label))

# One summary instead of one printed line per image.
print("Labels in record order:", labels)
print("Total number of images:", len(labels))

In [0]:
%%time
# GENERATE THE TFRecord for TESTING
t = GenerateTFRecord()

testing = 'F1_train_TEST'

t.convert_image_folder(train_name, testing+'.tfrecord', block_size=10)


In [0]:
# Read the block-shuffled TFRecord back and list every label in record order.
t = TFRecordExtractor(testing + '.tfrecord')
dataset = t.extract_image()

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

labels = []
with tf.Session() as sess:
    sess.run(iterator.initializer)
    while True:
        try:
            _, label = sess.run(next_element)
        except tf.errors.OutOfRangeError:
            # Raised by the iterator once every record has been read.
            break
        labels.append(int(label))

# One summary instead of one printed line per image.
print("Labels in record order:", labels)
print("Total number of images:", len(labels))