In [1]:
%load_ext autoreload
%autoreload 2

In [27]:
import tensorflow as tf
import os
from tensorflow.keras import models 

In [24]:
def load_and_preprocess_image_tensor(path, target_shape):
    """Read the image file at ``path`` and return it preprocessed.

    Args:
        path: string tensor (or Python str) holding the image file path.
        target_shape: (height, width) the decoded image is resized to.

    Returns:
        float32 tensor of shape (height, width, 3), exactly as produced by
        ``preprocess_image_tensor`` on the raw file bytes. When this runs
        inside ``Dataset.map`` (as ``get_dataloader`` uses it) the result is a
        symbolic graph tensor rather than an EagerTensor.
    """
    return preprocess_image_tensor(tf.io.read_file(path), target_shape)

def preprocess_image_tensor(image, target_shape):
    """Decode JPEG bytes, rescale pixels to [-1, 1] and resize.

    Args:
        image: scalar string tensor containing encoded JPEG bytes
            (e.g. the output of ``tf.io.read_file``).
        target_shape: (height, width) passed to ``tf.image.resize``; the
            default bilinear resize does not preserve aspect ratio.

    Returns:
        float32 tensor of shape (height, width, 3).

    NOTE(review): values are RGB in [-1, 1] before resizing. If these tensors
    are fed straight into the ImageNet VGG19 built later in this notebook,
    keras' ``vgg19.preprocess_input`` expects caffe-style input (BGR,
    ImageNet-mean-subtracted, 0-255 scale) -- confirm a conversion happens
    before the VGG call.
    """
    # Force 3 channels so grayscale/CMYK JPEGs still yield RGB tensors.
    rgb = tf.image.decode_jpeg(image, channels=3)
    # uint8 [0, 255] -> float32 [0, 1].
    unit_range = tf.image.convert_image_dtype(rgb, dtype=tf.float32)
    # Shift and scale [0, 1] onto [-1, 1].
    centered = (unit_range - 0.5) * 2
    return tf.image.resize(centered, target_shape)


def get_data_paths(data_dir, extensions=(".jpg", ".jpeg")):
    """Return the sorted paths of image files directly inside ``data_dir``.

    Args:
        data_dir: directory to scan (non-recursive).
        extensions: filename suffix, or tuple of suffixes, to accept. The
            comparison is case-insensitive, so ``photo.JPG`` matches ``.jpg``.
            The default covers both common JPEG suffixes because images are
            decoded with ``tf.image.decode_jpeg``.

    Returns:
        list[str]: ``os.path.join(data_dir, name)`` for every matching regular
        file (directories are skipped). The list is sorted because
        ``os.listdir`` order is arbitrary and would otherwise make runs
        irreproducible.

    Raises:
        OSError (e.g. FileNotFoundError) from ``os.listdir`` if ``data_dir``
        cannot be listed.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(extensions, str):
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    candidates = (
        os.path.join(data_dir, name)
        for name in os.listdir(data_dir)
        if name.lower().endswith(suffixes)
    )
    return sorted(path for path in candidates if os.path.isfile(path))


def get_dataloader(path_lst, batch_size, target_size=(400, 400),
                   shuffle_buffer_size=None, seed=None):
    """Build a shuffled, batched ``tf.data`` pipeline over image files.

    Args:
        path_lst: non-empty sequence of image file paths (str).
        batch_size: images per batch; the final batch may be smaller
            (``drop_remainder`` is left at its default of False).
        target_size: (height, width) each image is resized to.
        shuffle_buffer_size: shuffle buffer size, counted in *paths*.
            Defaults to ``len(path_lst)``, i.e. a full uniform reshuffle each
            epoch. Shuffling paths is cheap, unlike shuffling decoded images.
        seed: optional shuffle seed for reproducible ordering.

    Returns:
        tf.data.Dataset yielding float32 batches of shape
        (batch, target_size[0], target_size[1], 3).

    Raises:
        ValueError: if ``path_lst`` is empty.
    """
    if len(path_lst) == 0:
        # tf.constant([]) would be float32 and fail later inside read_file
        # with an obscure dtype error.
        raise ValueError("path_lst is empty; no images to load")

    def _load(path):
        # Decode + resize a single path; runs inside the tf.data graph.
        return load_and_preprocess_image_tensor(path, target_size)

    if shuffle_buffer_size is None:
        shuffle_buffer_size = len(path_lst)

    autotune = tf.data.experimental.AUTOTUNE
    paths = tf.data.Dataset.from_tensor_slices(tf.constant(path_lst, dtype=tf.string))

    # Shuffle BEFORE map: the buffer then holds short path strings rather than
    # decoded float32 images (~1.9 MB each at 400x400x3).
    return (paths
            .shuffle(buffer_size=shuffle_buffer_size, seed=seed)
            .map(_load, num_parallel_calls=autotune)
            .batch(batch_size)
            .prefetch(autotune))

In [17]:
# VGG19 layers read for the style representation: the first conv layer of
# each of the five blocks (block1_conv1 ... block5_conv1).
style_layers = ["block{}_conv1".format(block) for block in range(1, 6)]
# Single VGG19 layer read for the content representation.
content_layers = ["block5_conv2"]

In [9]:
# ImageNet-pretrained VGG19 convolutional base without the classifier head.
# No input_shape is given, so the input accepts any spatial size.
vgg = tf.keras.applications.VGG19(weights='imagenet', include_top=False)

In [6]:
# Freeze all VGG19 layers so the pretrained ImageNet weights are never updated;
# the network only supplies intermediate activations for the layers chosen above.
vgg.trainable = False

In [18]:
# Symbolic output tensors of the selected VGG19 layers, looked up by layer name
# (order follows style_layers / content_layers).
style_outputs = [vgg.get_layer(layer_name).output for layer_name in style_layers]
content_outputs = [vgg.get_layer(layer_name).output for layer_name in content_layers]

In [19]:
# Feature extractor: one input image -> list of activations, style layers first,
# then content layers.
model = models.Model(inputs=vgg.input, outputs=style_outputs + content_outputs)

In [31]:
# Style images, relative to the notebook's working directory.
STYLE_IMAGE_DIR = 'train/style'
path_lst = get_data_paths(STYLE_IMAGE_DIR)
# Fail fast with a clear message instead of an obscure tf.data dtype error later.
assert path_lst, "no JPEG images found in {!r}".format(STYLE_IMAGE_DIR)
dat = get_dataloader(path_lst, batch_size=1, target_size=(400, 400))

In [None]:
# Pull one batch to sanity-check the pipeline: expect (batch, height, width, 3).
img = next(iter(dat))
assert len(img.shape) == 4 and img.shape[-1] == 3, img.shape
img.shape

In [20]:
# Inspect VGG19's symbolic input. The recorded output shows shape
# (None, None, None, 3): any batch size and spatial size, 3 channels.
# NOTE(review): debugging leftover -- consider removing before sharing.
vgg.input

<tf.Tensor 'input_2:0' shape=(None, None, None, 3) dtype=float32>