In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules
# before each cell executes, so edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import tensorflow as tf
import os
from tensorflow.keras import models 

In [3]:
def load_and_preprocess_image_tensor(path, target_shape):
    """
    Read the file at PATH and turn it into a preprocessed image tensor.

    The raw file contents are handed to `preprocess_image_tensor`, which
    decodes them as a 3-channel JPEG, converts to float32 and resizes.

    Args:
        path: location of the image file (string or scalar string tensor).
        target_shape: (height, width) passed through to the resize step.

    Returns:
        A float32 tensor of shape (H, W, C) with C == 3.
    """
    raw_bytes = tf.io.read_file(path)
    preprocessed = preprocess_image_tensor(raw_bytes, target_shape)
    return preprocessed

def preprocess_image_tensor(image, target_shape):
    """
    Decode, normalise and resize an encoded JPEG image.

    Steps, in order:
      1. decode the encoded bytes as a 3-channel JPEG,
      2. convert to float32 (`convert_image_dtype` rescales integer
         pixel values into [0, 1]),
      3. shift and scale the values into [-1, 1],
      4. resize to `target_shape`.

    Args:
        image: scalar string tensor holding the encoded JPEG bytes.
        target_shape: (height, width) of the output image.

    Returns:
        A float32 tensor of shape (height, width, 3).
    """
    # NOTE(review): if these tensors are fed to an ImageNet-pretrained VGG,
    # confirm that [-1, 1] scaling is the input convention the weights
    # expect - TODO verify against the model's own preprocess_input.
    decoded = tf.image.decode_jpeg(image, channels=3)
    unit_range = tf.image.convert_image_dtype(decoded, dtype=tf.float32)
    centered = (unit_range - 0.5) * 2
    return tf.image.resize(centered, target_shape)


def get_data_paths(data_dir, extensions=(".jpg", ".jpeg")):
    """
    List the JPEG files directly inside DATA_DIR.

    Args:
        data_dir: directory to scan (not recursive).
        extensions: file-name suffixes to accept; compared
            case-insensitively, so "photo.JPG" matches ".jpg".

    Returns:
        A list of paths (DATA_DIR joined with each matching file name),
        sorted so the result is the same on every run and platform.
        Returns an empty list when nothing matches.

    Raises:
        OSError: propagated from `os.listdir` if DATA_DIR cannot be read.
    """
    # str.endswith accepts a tuple, so normalise the suffixes once up front.
    suffixes = tuple(ext.lower() for ext in extensions)

    matches = []
    for name in os.listdir(data_dir):
        if not name.lower().endswith(suffixes):
            continue
        full_path = os.path.join(data_dir, name)
        # Skip directories that merely happen to carry an image suffix;
        # they cannot be read as image files downstream.
        if os.path.isfile(full_path):
            matches.append(full_path)

    # os.listdir returns entries in arbitrary order; sort for reproducibility.
    return sorted(matches)


def get_dataloader(path_lst, batch_size, target_size=(400, 400)):
    """
    Build a shuffled, batched `tf.data.Dataset` of preprocessed images.

    Args:
        path_lst: sequence of image file paths.
        batch_size: number of images per batch.
        target_size: (height, width) every image is resized to.

    Returns:
        A dataset yielding float32 batches of shape
        (<= batch_size, height, width, 3); the final batch may be smaller.
    """
    def process_fn(path):
        return load_and_preprocess_image_tensor(path, target_size)

    # Pin the dtype so that an empty list still produces a string dataset
    # (tf.constant([]) would otherwise default to float32).
    path_ds = tf.data.Dataset.from_tensor_slices(
        tf.constant(path_lst, dtype=tf.string))

    # Shuffle the path strings *before* decoding: the shuffle buffer then
    # holds up to 10000 short strings rather than up to 10000 decoded
    # float32 images, which keeps memory use small.
    return (path_ds
            .shuffle(buffer_size=10000)
            .map(process_fn)
            .batch(batch_size))

In [4]:
# Layer names whose activations are extracted: the first convolution of each
# of the five blocks for style, and block 5's second convolution for content.
style_layers = ["block{}_conv1".format(block_idx) for block_idx in range(1, 6)]
content_layers = ["block5_conv2"]

In [45]:
# Define the backbone in the notebook itself so a fresh kernel
# (Restart & Run All) does not fail with NameError on `vgg`.
# include_top=False drops the dense classifier head, which gives the
# variable-size (None, None, None, 3) input shown in the summary.
# NOTE(review): weights="imagenet" is assumed (it is the Keras default) -
# confirm this matches the model originally used.
vgg = tf.keras.applications.VGG16(include_top=False, weights="imagenet")
vgg.summary()

Model: "vgg16"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, None, None, 3)]   0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, None, None, 64)    1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, None, None, 64)    36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, None, None, 64)    0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, None, None, 128)   73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, None, None, 128)   147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, None, None, 128)   0     

In [8]:
# Collect the symbolic output tensor of every selected VGG layer, keeping
# the order of the corresponding name lists.
style_outputs = [
    vgg.get_layer(layer_name).output
    for layer_name in style_layers
]
content_outputs = [
    vgg.get_layer(layer_name).output
    for layer_name in content_layers
]

In [29]:
# Gather the image paths under test/style, wrap them in a loader that
# yields one image per batch, and open an iterator over it.
path = get_data_paths("test/style")
data_loader = get_dataloader(path_lst=path, batch_size=1)
it = iter(data_loader)

In [31]:
# Pull the next batch from the loader iterator; the loader was built with
# a batch size of 1, so this is a single preprocessed image with a leading
# batch dimension.
im = next(it)

In [9]:
# Feature extractor: takes the VGG input and returns the style-layer
# activations followed by the content-layer activations.
model = models.Model(
    inputs=vgg.input,
    outputs=style_outputs + content_outputs,
)

In [35]:
# Run the batch through the extractor; `res` holds one activation tensor per
# requested layer (style layers first, then content layers).
res = model(im)

# Print each activation's shape so the cell actually produces the shape
# listing recorded as its output.
for feature_map in res:
    print(feature_map.shape)

(1, 400, 400, 64)
(1, 200, 200, 128)
(1, 100, 100, 256)
(1, 50, 50, 512)
(1, 25, 25, 512)
(1, 25, 25, 512)
