In [None]:
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tensorflow as tf

# Switch TensorFlow to eager execution: the cells below iterate over
# tf.data datasets with plain Python loops and call .numpy() on tensors.
tf.enable_eager_execution()

from skimage import io, img_as_float, img_as_ubyte
from skimage.transform import rescale, resize
from skimage.util import view_as_blocks

# Feature schema passed to tf.parse_single_example when reading Examples.
# The keys mirror the features written by GenerateTFRecord._convert_image.
features = {
    'rows': tf.FixedLenFeature([], tf.int64),  # first dimension of the image array shape
    'cols': tf.FixedLenFeature([], tf.int64),  # second dimension of the image array shape
    'channels': tf.FixedLenFeature([], tf.int64),  # third dimension, or 1 for 2-D images
    'image': tf.FixedLenFeature([], tf.string),  # raw bytes of the encoded image file
    'label': tf.FixedLenFeature([], tf.int64)  # integer parsed from the file name
}

%load_ext autotime

In [None]:
def get_label_with_filename(filename):
    """Return the integer after the underscore of a '<a>_<b>.<ext>' file name.

    Only the last '/'-separated component of `filename` is inspected and
    everything from its first '.' onwards is ignored.

    Raises:
        ValueError: if the remaining stem does not split into exactly two
            '_'-separated parts, or if the second part is not an integer.

    NOTE(review): despite the name, the part before the underscore is
    discarded and the count part is returned -- confirm this is intended.
    """
    stem = filename.rpartition('/')[2].partition('.')[0]
    _, img_count = stem.split('_')
    return int(img_count)

def get_patch_count(filename):
    """Parse the number that follows the underscore in a file name.

    The directory part (up to the last '/') and everything from the first
    '.' of the base name are dropped; the remaining stem must have the form
    '<prefix>_<number>'.

    Raises:
        ValueError: if the stem does not contain exactly one underscore or
            the part after it is not an integer.
    """
    base_name = filename.split('/')[-1]
    stem = base_name.split('.', 1)[0]
    _prefix, count_text = stem.split('_')
    return int(count_text)

def get_count(filename):
    """Return the integer that forms the stem of a '<number>.<ext>' file name.

    The stem is the last '/'-separated component of `filename`, cut at its
    first '.'. Raises ValueError if that stem is not an integer.
    """
    return int(filename.rpartition('/')[2].partition('.')[0])

def get_anomaly_count(filename):
    """Return the trailing integer of a file name stem.

    The stem is the last '/'-separated component of `filename`, cut at its
    first '.'. The text after the stem's last underscore (or the whole stem
    when it has no underscore) is converted to int.

    Raises:
        ValueError: if that trailing text is not an integer.
    """
    base_name = filename.rpartition('/')[2]
    stem = base_name.partition('.')[0]
    return int(stem.rsplit('_', 1)[-1])

In [None]:
# CLASS TO GENERATE A TFRECORD FROM ALL THE IMAGES IN A FOLDER
class GenerateTFRecord:
    """Write every image file of a folder into a single TFRecord file.

    Each image becomes one tf.train.Example holding the encoded file bytes,
    the array shape reported by matplotlib and an integer label parsed from
    the file name with get_anomaly_count.
    """

    def convert_image_folder(self, img_folder, tfrecord_file_name):
        """Serialize the images found directly inside `img_folder`.

        Args:
            img_folder: directory containing the image files. File names must
                be parseable by get_anomaly_count.
            tfrecord_file_name: path of the TFRecord file to create.

        Hidden entries (names starting with '.') and sub-directories are
        skipped: os.listdir also returns things like '.ipynb_checkpoints',
        which can neither be parsed by get_anomaly_count nor read as images.
        Records are written in ascending order of the parsed number.
        """
        img_paths = []
        for entry in os.listdir(img_folder):
            path = os.path.abspath(os.path.join(img_folder, entry))
            if entry.startswith('.') or not os.path.isfile(path):
                continue
            img_paths.append(path)

        # Sort on the number in the base name; passing only the base name keeps
        # the parsing independent of the platform's path separator.
        img_paths.sort(key=lambda p: get_anomaly_count(os.path.basename(p)))

        with tf.python_io.TFRecordWriter(tfrecord_file_name) as writer:
            for img_path in img_paths:
                example = self._convert_image(img_path)
                writer.write(example.SerializeToString())

    def _convert_image(self, img_path):
        """Build the tf.train.Example for one image file.

        Args:
            img_path: path of the image file.

        Returns:
            A tf.train.Example with the int64 features 'rows', 'cols',
            'channels' and 'label' and the bytes feature 'image' (the file
            content as stored on disk, not the decoded pixels).
        """
        label = get_anomaly_count(os.path.basename(img_path))
        # The image is decoded here only to learn its shape.
        img_shape = mpimg.imread(img_path).shape

        # Read the encoded image file as raw bytes
        with tf.gfile.GFile(img_path, 'rb') as fid:
            image_data = fid.read()

        # If the image has more than one channel write the number of channels,
        # otherwise write a 1 in the channels feature
        channels = img_shape[2] if len(img_shape) == 3 else 1

        def _int64(value):
            # Wrap a single integer as an int64 feature.
            return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

        example = tf.train.Example(features=tf.train.Features(feature={
            'rows': _int64(img_shape[0]),
            'cols': _int64(img_shape[1]),
            'channels': _int64(channels),
            'image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[image_data])),
            'label': _int64(label),
        }))
        return example
    
    
# CLASS TO EXTRACT IMAGES FROM A TFRECORD AND RETURN A DATASET
class TFRecordExtractor:
    """Read a TFRecord file and expose it as a tf.data dataset.

    Records are parsed with the module-level `features` schema.
    """

    def __init__(self, tfrecord_file):
        """Remember the absolute path of the TFRecord file to read."""
        self.tfrecord_file = os.path.abspath(tfrecord_file)

    def _extract_fn(self, tfrecord, with_label=False):
        """Parse one serialized Example into an image tensor.

        Args:
            tfrecord: scalar string tensor holding a serialized Example.
            with_label: when True, also return the stored int64 label.

        Returns:
            The decoded image as float32 scaled from [0, 255] to [0.0, 1.0],
            or the tuple (image, label) when `with_label` is True.
        """
        # Extract the data record
        sample = tf.parse_single_example(tfrecord, features)

        # cast image [0, 255] to [0.0, 1.0]
        image = tf.image.decode_image(sample['image'], dtype=tf.uint8)
        image = tf.cast(image, tf.float32) / 255

        if with_label:
            # 'label' is declared as tf.int64 in `features`, so no cast is needed.
            return image, sample['label']
        return image

    def extract_image(self, with_label=False):
        """Build the dataset of parsed records.

        Args:
            with_label: when False (default) each element is an image tensor;
                when True each element is an (image, label) tuple, which is
                the element structure two-argument map functions such as
                map_test and get_patches take.

        Returns:
            A tf.data dataset over the records of the TFRecord file.
        """
        # Pipeline of dataset
        dataset = tf.data.TFRecordDataset([self.tfrecord_file])
        return dataset.map(lambda record: self._extract_fn(record, with_label))

In [None]:
# CREATE A TFRECORD FROM AN IMAGES FOLDER
t = GenerateTFRecord()

datasets_path = '../Datasets/'
tfrecords_path = '../TFRecords/'

# Create the output directory (and any missing parents); re-running the cell
# when it already exists is not an error.
os.makedirs(tfrecords_path, exist_ok=True)

#train_name = 'Bark-dataset-Train'
test_name = 'Bark-Clear-Quantitative'

# Writes <tfrecords_path>/<test_name>.tfrecord from the images in <datasets_path>/<test_name>
#t.convert_image_folder(datasets_path+train_name, tfrecords_path+train_name+'.tfrecord')
t.convert_image_folder(datasets_path+test_name, tfrecords_path+test_name+'.tfrecord')

In [None]:
# Side length (pixels) that random_crop resizes its output to.
image_size = 256
# Side length (pixels) of the square patches cut out by get_patches.
input_size = 128
# Stride (pixels) between neighbouring patches in get_patches.
# Note: random_crop assigns a local variable of the same name; that local is
# a random crop side length and does not read or change this value.
crop_size = 32

# Threshold compared in map_test against a uniform draw from [0, 1):
# the random crop is applied when the draw is below this value.
crop_ratio = 0.5
def random_crop(image):
    """Cut a random square out of `image` and resize it to image_size.

    The side of the square is drawn uniformly from the integers in
    [150, 200); the crop keeps 3 channels. The result is resized to
    (image_size, image_size).
    """
    side = tf.random_uniform(shape=[], minval=150, maxval=200, dtype=tf.int32)
    cropped = tf.image.random_crop(image, size=(side, side, 3))
    return tf.image.resize(cropped, size=(image_size, image_size))

def map_test(image, label):
    """Randomly augment `image` and pass `label` through unchanged.

    With probability crop_ratio the image is first replaced by
    random_crop(image). It then goes through random left/right and up/down
    flips followed by random brightness, saturation, contrast and hue
    changes, in that order.

    Returns:
        The tuple (augmented image, label).
    """
    draw = tf.random_uniform(shape=[], minval=0., maxval=1., dtype=tf.float32)
    image = tf.cond(draw < crop_ratio, lambda: random_crop(image), lambda: image)

    # Applied one after the other, in this order.
    augmentations = (
        tf.image.random_flip_left_right,
        tf.image.random_flip_up_down,
        lambda img: tf.image.random_brightness(img, max_delta=0.25),
        lambda img: tf.image.random_saturation(img, lower=0.75, upper=1.25),
        lambda img: tf.image.random_contrast(img, lower=0.75, upper=1.25),
        lambda img: tf.image.random_hue(img, max_delta=0.15),
    )
    for augment in augmentations:
        image = augment(image)
    return image, label

def get_patches(image, label):
    """Cut `image` into overlapping square patches.

    Args:
        image: tensor that can be reshaped to (image_size, image_size, 3).
        label: accepted so the function can be mapped over (image, label)
            datasets; it is not used and not returned.

    Returns:
        A tensor of shape (num_patches, input_size, input_size, 3). Patches
        are input_size wide and taken every crop_size pixels with 'VALID'
        padding, i.e. (image_size - input_size) // crop_size + 1 positions
        along each axis.
    """
    ksizes = [1, input_size, input_size, 1]
    strides = [1, crop_size, crop_size, 1]
    rates = [1, 1, 1, 1]
    # extract_image_patches expects a batch dimension; use image_size rather
    # than a literal so the reshape follows the configured image size.
    batched = tf.reshape(image, (1, image_size, image_size, 3))
    patches = tf.image.extract_image_patches(batched, ksizes, strides, rates, padding='VALID')
    return tf.reshape(patches, (-1, input_size, input_size, 3))

def reduced_train_input_fn(dataset, samples=100):
    """Map get_patches over `dataset` and keep only the first `samples` elements.

    Args:
        dataset: dataset whose elements match get_patches' (image, label)
            arguments.
        samples: number of mapped elements to keep. The patches are not
            unbatched, so each element is the full patch tensor of one input
            element.

    Returns:
        The truncated dataset with a prefetch buffer of one element.
    """
    patched = dataset.map(get_patches, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    # prefetch(1) keeps one element ready while the previous one is consumed
    return patched.take(samples).prefetch(1)

In [None]:
# TEST THE READING AND PARSING OF THE TRAINING TFRECORD INTO A DATASET
# Prints the shape of every batch and image and draws the images four at a
# time on 2x2 grids, titled with their running index.
t = TFRecordExtractor(tfrecords_path+test_name+'.tfrecord')
dataset = t.extract_image().batch(8)
#dataset = reduced_train_input_fn(dataset)
#dataset = dataset.map(map_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)
#dataset = dataset.map(get_patches, num_parallel_calls=tf.data.experimental.AUTOTUNE)

patch_count = 0
i, j = 0, 0
for batch in dataset:
    print(batch.shape)

    for image in batch:
        print(image.shape)

        # Row/column of this image inside the current 2x2 grid
        i, j = divmod(patch_count % 4, 2)

        # First slot: start a new figure
        if (i, j) == (0, 0):
            f, axarr = plt.subplots(ncols=2, nrows=2, figsize=(15, 7))

        axarr[i,j].imshow(image)
        axarr[i,j].set_title(str(patch_count))

        # Last slot: the grid is full, display it
        if (i, j) == (1, 1):
            plt.show()

        patch_count += 1

In [None]:
# TEST THE READING AND PARSING OF THE TRAINING TFRECORD INTO A DATASET
# Draws the first 10 images of the TFRecord on 2x2 grids.
t = TFRecordExtractor(tfrecords_path+test_name+'.tfrecord')
dataset = t.extract_image()
#dataset = dataset.map(map_test, num_parallel_calls=tf.data.experimental.AUTOTUNE)

patch_count = 0
i, j = 0, 0
for sample in dataset.take(10):
    if i+j == 0:
        f, axarr = plt.subplots(ncols=2, nrows=2, figsize=(15, 7))

    # extract_image() yields one image tensor per record (no label), so the
    # sample itself is the image; indexing it would slice pixel rows instead.
    image = sample.numpy()

    axarr[i,j].imshow(image)
    axarr[i,j].set_title(str(patch_count))
    patch_count += 1

    if i==1 and j==1:
        plt.show()
        i, j = 0, 0
    else:
        if j==1:
            i += 1
            j = 0
        else:
            j += 1

# Show the last grid when the number of samples is not a multiple of four.
if i+j != 0:
    plt.show()

In [None]:
# Disabled snippet kept as a bare string literal (it is never executed):
# a tf.Session / initializable-iterator loop that reads 10 samples from
# `dataset`, prints image shape, label and value range, and plots each image.
# It indexes each sample as (image, label).
'''
iterator = dataset.make_initializable_iterator()
el = iterator.get_next()
# print(el)

with tf.Session() as sess:
    sess.run(iterator.initializer)
    for i in range(10):
        sample = sess.run(el)
        img = sample[0]
        label = sample[1]

        print(img.shape, label)
        img = np.reshape(img, newshape= (img.shape[0], img.shape[1], img.shape[2]) )
        print(img.max(), img.min())

        f, axarr = plt.subplots(ncols=1, nrows=1, figsize=(1, 1))
        axarr.imshow(img)
        axarr.set_title('Example image')
        plt.show()
'''