In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported before each cell executes and edits to them take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path
import sys
try:
    from utils.model import make_model, freeze_all_base_model, unfreeze_last_base_model, loss_definition, initial_model, \
        callbacks_definition, train
    from utils.data import filter_binary_labels, optimize_dataset, prepare_sample_dataset, true_or_false, \
        dataset_definition
except ModuleNotFoundError:
    sys.path.insert(0, str(Path('.').resolve().parent))
    from utils.model import make_model, freeze_all_base_model, unfreeze_last_base_model, loss_definition, initial_model, \
        callbacks_definition, train
    from utils.data import filter_binary_labels, optimize_dataset, prepare_sample_dataset, true_or_false, \
        dataset_definition

In [3]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.applications.vgg16 as vgg16
import tensorflow_datasets as tfds
from tensorflow.keras.applications import vgg16, vgg19, densenet, resnet_v2, inception_v3, resnet50, resnet

In [4]:
# Root output directory for the generated TFRecord files; the per-split
# suffixes ('/train/dataset_', '/test/dataset_', ...) are appended at call time.
file_path = '../data/tfrecords/patch_camelyon/resnet152v2'
# Number of dataset batches accumulated before one TFRecord file is written
# (read as a global by write_tfrecord_dataset).
n_batches_concat = 10
# Batch size forwarded to dataset_definition.
batch_size = 64

In [5]:
# Load the 'patch_camelyon' splits via the project helper; 224x224 matches the
# input size used when the feature-extraction model is built.
train_ds, test_ds, valid_ds, class_names = dataset_definition(sample_dataset='patch_camelyon', batch_size=batch_size, img_height=224, img_width=224)

2022-07-04 23:47:17.809 INFO    absl: Load dataset info from C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0
2022-07-04 23:47:17.820 INFO    absl: Reusing dataset patch_camelyon (C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0)
2022-07-04 23:47:17.821 INFO    absl: Constructing tf.data.Dataset patch_camelyon for split ['train', 'validation'], from C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0
2022-07-04 23:47:20.044 INFO    absl: Load dataset info from C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0
2022-07-04 23:47:20.046 INFO    absl: Reusing dataset patch_camelyon (C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0)
2022-07-04 23:47:20.047 INFO    absl: Constructing tf.data.Dataset patch_camelyon for split ['test'], from C:\Users\lucas\tensorflow_datasets\patch_camelyon\2.0.0


In [6]:
def make_base_model(img_height=224, img_width=224, transfer_learning=True, base_model='vgg16'):
    """Build a headless Keras application preceded by its matching preprocessing.

    Args:
        img_height: Height, in pixels, of the images fed to the model.
        img_width: Width, in pixels, of the images fed to the model.
        transfer_learning: When truthy the backbone is created with
            ``weights='imagenet'``; otherwise with ``weights=None``.
        base_model: Name of the backbone architecture. Must be one of the keys
            of the ``backbones`` table below.

    Returns:
        A ``tf.keras.Model`` that takes ``(img_height, img_width, 3)`` images,
        applies the backbone-specific ``preprocess_input`` and returns the
        output of the backbone built with ``include_top=False``.

    Raises:
        ValueError: If ``base_model`` is not a supported architecture name.
    """
    # Single lookup table instead of an if/elif ladder. Each entry is a thunk so
    # that only the requested architecture's module attributes are resolved.
    backbones = {
        'vgg16': lambda: (vgg16.VGG16, vgg16.preprocess_input),
        'vgg19': lambda: (vgg19.VGG19, vgg19.preprocess_input),
        'densenet201': lambda: (densenet.DenseNet201, densenet.preprocess_input),
        'densenet169': lambda: (densenet.DenseNet169, densenet.preprocess_input),
        'densenet121': lambda: (densenet.DenseNet121, densenet.preprocess_input),
        'resnet152v2': lambda: (resnet_v2.ResNet152V2, resnet_v2.preprocess_input),
        'resnet50': lambda: (resnet50.ResNet50, resnet50.preprocess_input),
        'resnet152': lambda: (resnet.ResNet152, resnet.preprocess_input),
        'resnet101': lambda: (resnet.ResNet101, resnet.preprocess_input),
        'inception_v3': lambda: (inception_v3.InceptionV3, inception_v3.preprocess_input),
    }

    # Fail early with a readable message; previously an unknown name fell
    # through every branch and crashed later on an unbound local variable.
    if base_model not in backbones:
        supported = ', '.join(sorted(backbones))
        raise ValueError(f"Unsupported base_model {base_model!r}; expected one of: {supported}")

    weights = 'imagenet' if transfer_learning else None

    constructor, preprocess_layer = backbones[base_model]()
    base_model_net = constructor(include_top=False, weights=weights)

    inputs = layers.Input(shape=(img_height, img_width, 3))
    x = preprocess_layer(inputs)
    # The backbone is called with training=False, exactly as before.
    outputs = base_model_net(x, training=False)

    return tf.keras.Model(inputs, outputs)

In [7]:
# Feature extractor used for the export below: ResNet152V2 built with
# weights='imagenet' and include_top=False (see make_base_model).
model = make_base_model(img_height=224, img_width=224, transfer_learning=True, base_model='resnet152v2')

In [8]:
from tqdm import tqdm
import numpy as np
import gc

In [9]:
def _bytes_feature(value):
    """Wrap a string / byte value in a bytes_list ``tf.train.Feature``."""
    eager_tensor_type = type(tf.constant(0))
    if isinstance(value, eager_tensor_type):
        # Tensors are converted to their underlying value before wrapping.
        value = value.numpy()
    wrapped = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=wrapped)

def _float_feature(value):
    """Wrap a single float / double in a float_list ``tf.train.Feature``."""
    wrapped = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=wrapped)

def _int64_feature(value):
    """Wrap a single bool / enum / int / uint in an int64_list ``tf.train.Feature``."""
    wrapped = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=wrapped)

def serialize_array(array):
    """Return ``array`` serialized with ``tf.io.serialize_tensor``."""
    return tf.io.serialize_tensor(array)

def parse_single_image(image, label):
    """Pack one image and its label into a ``tf.train.Example``.

    The example stores the first three entries of ``image.shape`` under
    'height', 'width' and 'depth', the serialized image under 'raw_image'
    and the label under 'label'.
    """
    # The first three axes of the image map onto these int64 features, in order.
    dimension_keys = ('height', 'width', 'depth')
    feature_map = {
        key: _int64_feature(image.shape[axis])
        for axis, key in enumerate(dimension_keys)
    }
    feature_map['raw_image'] = _bytes_feature(serialize_array(image))
    feature_map['label'] = _int64_feature(label)

    # Wrap the individual features into a single Example message.
    features = tf.train.Features(feature=feature_map)
    return tf.train.Example(features=features)

def write_images(images, labels, filename:str="images"):
    """Write paired images and labels to a single TFRecord file.

    Args:
        images: Sized, iterable collection of images; each element is passed to
            ``parse_single_image``.
        labels: Sized, iterable collection of labels, one per image.
        filename: Output path without extension; ".tfrecords" is appended and
            the result is resolved to an absolute path.

    Returns:
        The number of records written.

    Raises:
        ValueError: If ``images`` and ``labels`` differ in length.
    """
    # Reject mismatched inputs up front rather than failing midway through the
    # file (labels too short) or silently dropping labels (labels too long).
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, got {len(images)} and {len(labels)}"
        )

    target = Path(filename + ".tfrecords").resolve()
    # The writer does not create missing directories itself.
    target.parent.mkdir(parents=True, exist_ok=True)

    count = 0
    # The context manager closes the file even if serialization raises.
    with tf.io.TFRecordWriter(str(target)) as writer:
        for current_image, current_label in zip(images, labels):
            out = parse_single_image(image=current_image, label=current_label)
            writer.write(out.SerializeToString())
            count += 1

    return count

def write_tfrecord_dataset(dataset, model, file_path, n_batches_per_file=None):
    """Run every batch of ``dataset`` through ``model`` and save the outputs as TFRecords.

    Model outputs and labels are accumulated over ``n_batches_per_file``
    batches and then written by ``write_images`` to
    ``file_path + str(k) + ".tfrecords"`` with k = 1, 2, ...

    Args:
        dataset: Iterable yielding ``(images, labels)`` batches; ``labels``
            must provide ``.numpy()``.
        model: Object whose ``predict(images, verbose=0)`` returns an array of
            per-image outputs.
        file_path: Path prefix of the output files (directory plus file stem).
        n_batches_per_file: Number of batches stored in each file. Defaults to
            the notebook-level ``n_batches_concat`` when omitted.

    Raises:
        ValueError: If the number of batches per file is smaller than 1.
    """
    if n_batches_per_file is None:
        # Backward compatible: fall back to the notebook's global setting.
        n_batches_per_file = n_batches_concat
    if n_batches_per_file < 1:
        raise ValueError(f"n_batches_per_file must be >= 1, got {n_batches_per_file}")

    feature_chunks = []
    label_chunks = []
    files_written = 0

    def _flush():
        """Write the accumulated chunks to the next numbered file and reset the buffers."""
        nonlocal files_written
        files_written += 1
        # Concatenate once per file instead of re-copying the buffer every batch.
        write_images(
            np.concatenate(feature_chunks),
            np.concatenate(label_chunks),
            filename=file_path + str(files_written),
        )
        feature_chunks.clear()
        label_chunks.clear()
        gc.collect()

    for images, labels in tqdm(dataset):
        # verbose=0 keeps predict from printing its own progress bar per batch.
        feature_chunks.append(model.predict(images, verbose=0))
        label_chunks.append(labels.numpy())

        if len(feature_chunks) == n_batches_per_file:
            _flush()

    # Previously a trailing group smaller than n_batches_per_file was silently
    # dropped; write whatever is left so no samples are lost.
    if feature_chunks:
        _flush()

In [10]:
# Pass the training split through the feature extractor and write the outputs
# to TFRecord files whose names start with '<file_path>/train/dataset_'.
write_tfrecord_dataset(dataset=train_ds, model=model, file_path=file_path + '/train/dataset_')

100%|████████████████████████████████████████████████████████████████████████████| 4096/4096 [1:05:40<00:00,  1.04it/s]


In [11]:
# Same export for the test split, written under '<file_path>/test/dataset_'.
write_tfrecord_dataset(dataset=test_ds, model=model, file_path=file_path + '/test/dataset_')

100%|████████████████████████████████████████████████████████████████████████████████| 512/512 [08:42<00:00,  1.02s/it]


In [12]:
# Same export for the validation split, written under '<file_path>/validation/dataset_'.
write_tfrecord_dataset(dataset=valid_ds, model=model, file_path=file_path + '/validation/dataset_')

100%|████████████████████████████████████████████████████████████████████████████████| 512/512 [09:10<00:00,  1.08s/it]
