In [6]:
from PIL import Image

import numpy as np
import numpy.linalg as la
import tensorflow as tf
from tensorflow.keras.preprocessing import image as kimage
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.models import Model


In [12]:
class ClientResNet(tf.keras.Model):
    """Client-side half of a split, frozen ImageNet ResNet50.

    The wrapped sub-model ``seq0`` runs the backbone from its 224x224x3 input
    up to and including the ``conv3_block4_out`` layer (the end of ResNet50's
    third stage). Its intermediate feature map is what gets handed to the
    server-side half.

    NOTE(review): ``predict`` below is a custom method that shadows
    ``tf.keras.Model.predict``. It handles a single image, not a batched dataset.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Headless backbone with ImageNet weights, frozen so it is never fine-tuned.
        backbone = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
        backbone.trainable = False

        # Truncate the graph at the split point shared with the server half.
        split_tensor = backbone.get_layer('conv3_block4_out').output
        self.seq0 = Model(inputs=backbone.input, outputs=split_tensor)

    def preprocess_image(self, img):
        """Turn one image into a preprocessed batch of size 1.

        Args:
            img: a single image that ``img_to_array`` accepts (e.g. a PIL image);
                presumably already 224x224 RGB — the caller is expected to resize.

        Returns:
            A numpy array with a leading batch axis of length 1, transformed by
            ResNet50's ``preprocess_input``.
        """
        batch = np.expand_dims(kimage.img_to_array(img), axis=0)
        return preprocess_input(batch)

    def predict(self, inputs):
        """Preprocess ``inputs`` and return the ``conv3_block4_out`` activations.

        The result is the feature map for the single image, with a batch axis
        of 1. For ResNet50 at 224x224 this is presumably (1, 28, 28, 512) —
        confirm.
        """
        return self.seq0(self.preprocess_image(inputs))
    
class ServerResnet(tf.keras.Model):
    """Server-side half of a split, frozen ImageNet ResNet50.

    ``seq1`` runs the backbone from the ``conv3_block4_out`` activations (the
    client's split point) through ``conv5_block3_out`` (the last conv block
    before pooling). It then applies global max pooling, producing one
    feature vector per sample ((1, 2048) for the single-image run in this
    notebook).

    NOTE(review): ``predict`` below is a custom method that shadows
    ``tf.keras.Model.predict``.
    """

    def __init__(self, *args, **kwargs):
        super(ServerResnet, self).__init__(*args, **kwargs)

        # Headless backbone with ImageNet weights, frozen so it is never fine-tuned.
        resnet_base = ResNet50(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
        resnet_base.trainable = False

        # Sub-graph from the client's split point to the final conv output.
        middle_layer = resnet_base.get_layer('conv3_block4_out')
        last_conv_layer = resnet_base.get_layer('conv5_block3_out')

        self.seq1 = tf.keras.Sequential([
            Model(inputs=middle_layer.output, outputs=last_conv_layer.output),
            tf.keras.layers.GlobalMaxPooling2D(),
        ])

    def predict(self, inputs):
        """Finish the forward pass and return L2-normalised embeddings.

        Args:
            inputs: the client half's ``conv3_block4_out`` activations, with a
                leading batch axis.

        Returns:
            A float tensor of shape (batch, features). Each row has unit L2
            norm; an all-zero row stays zero instead of becoming NaN.
        """
        server_embeddings = self.seq1(inputs)

        # Collapse everything after the batch axis. After global pooling this
        # is already 2-D, so it is a no-op, but it keeps the (batch, features)
        # contract explicit without building a new Flatten layer per call.
        flatten_result = tf.reshape(server_embeddings, (tf.shape(server_embeddings)[0], -1))

        # Normalise each sample independently (axis=1). Dividing by
        # numpy.linalg.norm of the whole 2-D tensor used a single batch-wide
        # norm, which is only correct for batch size 1. This also stays in
        # TensorFlow ops, so it works inside tf.function.
        return tf.math.l2_normalize(flatten_result, axis=1)

def get_embeddings(img, client_model=None, server_model=None, verbose=True):
    """Run one image through the client and server halves and return its embedding.

    Args:
        img: the image passed to ``client_model.predict`` (in this notebook, a
            224x224 RGB PIL image).
        client_model: object with a ``predict`` method producing the
            intermediate features. Defaults to the module-level ``ClientModel``.
        server_model: object with a ``predict`` method mapping those features
            to the final embedding. Defaults to the module-level ``ServerModel``.
        verbose: when True (the default, matching the original behaviour),
            print the embedding's shape.

    Returns:
        Whatever ``server_model.predict`` returns.
    """
    # Resolve the defaults at call time rather than in the signature, so the
    # global models can be created (or re-created) after this function is defined.
    if client_model is None:
        client_model = ClientModel
    if server_model is None:
        server_model = ServerModel

    client_embeddings = client_model.predict(img)
    server_embeddings = server_model.predict(client_embeddings)
    if verbose:
        print(server_embeddings.shape)

    return server_embeddings

# Build the two halves of the split network once; get_embeddings falls back to
# these module-level names when no models are passed explicitly.
ClientModel = ClientResNet()
ServerModel = ServerResnet()

image_path = "jeans.png"
img = Image.open(image_path)
# Convert to RGB *before* resizing. Pillow always uses NEAREST resampling for
# "1" and "P" (palette) images, which PNGs often are, so converting first lets
# the default resampling filter produce a smooth 224x224 input.
image = img.convert("RGB").resize((224, 224))
get_embeddings(image)


(1, 2048)


<tf.Tensor: shape=(1, 2048), dtype=float32, numpy=
array([[0.02651525, 0.073531  , 0.        , ..., 0.        , 0.00534359,
        0.01297403]], dtype=float32)>