# In [None]:
import tensorflow as tf
# Turn on XLA JIT compilation globally for TensorFlow graphs built below.
tf.config.optimizer.set_jit(enabled=True)

class WrappedModel(tf.Module):
    """``tf.Module`` wrapper that exposes a Keras model through a traced ``__call__``.

    Wrapping the Keras model in a ``tf.Module`` whose ``__call__`` is a
    ``tf.function`` allows a concrete function (and therefore an explicit
    serving signature) to be traced and passed to ``tf.saved_model.save``.

    Args:
        model: Optional Keras model to wrap. When omitted, a new
            ``tf.keras.applications.ResNet50()`` is built (Keras defaults,
            i.e. ImageNet weights), matching the original behaviour.
    """

    def __init__(self, model=None):
        super().__init__()
        # Assigning to an attribute of a tf.Module lets it track the model's
        # variables so they are included in the SavedModel.
        self.model = tf.keras.applications.ResNet50() if model is None else model

    @tf.function
    def __call__(self, x):
        """Run the wrapped model on the batch ``x`` and return its output tensor."""
        return self.model(x)

# Build the wrapper and trace a concrete function whose single input is a
# float32 tensor named 'input_0' with a fully dynamic shape; Triton addresses
# the input by this name.
model = WrappedModel()
serving_input = tf.TensorSpec(shape=[None, None, None, None],
                              dtype=tf.float32,
                              name='input_0')
call = model.__call__.get_concrete_function(serving_input)

# Export into Triton's versioned repository layout:
#   <repository>/<model name>/<version>/model.savedmodel
export_path = 'models/simple-tensorflow-model/1/model.savedmodel'
tf.saved_model.save(model, export_path, signatures=call)

import json

# Human-readable class names, indexed by the class id the model predicts.
# JSON text is UTF-8 by specification, so decode it explicitly instead of
# relying on the platform's locale encoding.
with open('./imagenet-simple-labels.json', encoding='utf-8') as file:
    labels = json.load(file)
    
#print(labels[:5])

import numpy as np
from PIL import Image

img_path = './assets/goldfish.jpg'
# Image.open is lazy: it keeps the file handle open until the pixels are read.
# Since the image is only displayed optionally, copy it into memory and close
# the file immediately rather than leaking the handle until garbage collection.
with Image.open(img_path) as source_image:
    image_pil = source_image.copy()
# Evaluate `image_pil` in a cell to preview the input image.

from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from tensorflow.keras.preprocessing import image

# Reference Keras ResNet50 with ImageNet weights, used only by the optional
# local sanity check at the end of this cell.
model = ResNet50(weights='imagenet')

# Load the test image at 224x224, add a leading batch axis, and apply the
# resnet50 module's preprocess_input to produce the request payload.
img = image.load_img(img_path, target_size=(224, 224))
image_numpy = image.img_to_array(img)
image_numpy = np.expand_dims(image_numpy, axis=0)
image_numpy = preprocess_input(image_numpy)

# Set to True to print the local Keras top-3 predictions for comparison with
# Triton's answer. This is real (syntax-checked) code rather than a dead
# string literal.
RUN_LOCAL_CHECK = False
if RUN_LOCAL_CHECK:
    preds = model.predict(image_numpy)
    print('Predicted:', decode_predictions(preds, top=3)[0])

# Triton model configuration: a TensorFlow SavedModel that takes batches
# (up to 32) of NHWC 224x224x3 float32 images on 'input_0' and returns 1000
# float32 scores on 'output_0'.
configuration = """
name: "simple-tensorflow-model"
platform: "tensorflow_savedmodel"
max_batch_size: 32
input [
  {
      name: "input_0"
      data_type: TYPE_FP32
      format: FORMAT_NHWC
      dims: [ 224, 224, 3 ]
  }
]
output {
    name: "output_0"
    data_type: TYPE_FP32
    dims: [ 1000 ]
}
"""

# config.pbtxt sits next to the numbered version directories of the model.
config_path = './models/simple-tensorflow-model/config.pbtxt'
with open(config_path, mode='w') as config_file:
    config_file.write(configuration)

#!sleep 45
#!curl -v triton:8000/v2/health/ready
#!curl -v triton:8000/v2/models/simple-tensorflow-model

import tritonclient.http as tritonhttpclient
from tritonclient.utils import triton_to_np_dtype


VERBOSE = False
input_name = 'input_0'
input_dtype = 'FP32'
output_name = 'output_0'
model_name = 'simple-tensorflow-model'
url = 'triton:8000'
model_version = '1'

# Cast the payload to the numpy dtype that matches `input_dtype` (the client
# rejects a dtype mismatch). Take the request shape from the array itself so
# it can never drift from the data; for one image this is (1, 224, 224, 3).
input_data = np.asarray(image_numpy, dtype=triton_to_np_dtype(input_dtype))
input_shape = input_data.shape


triton_client = tritonhttpclient.InferenceServerClient(url=url, verbose=VERBOSE)
# Not used further; presumably kept so the notebook fails early (with a
# client exception) if the server does not know the model. TODO confirm.
model_metadata = triton_client.get_model_metadata(model_name=model_name, model_version=model_version)
model_config = triton_client.get_model_config(model_name=model_name, model_version=model_version)


input0 = tritonhttpclient.InferInput(input_name, input_shape, input_dtype)
# binary_data=True sends the tensor as raw bytes rather than as a JSON list
# of ~150k numbers, which is much smaller and faster to encode and parse.
input0.set_data_from_numpy(input_data, binary_data=True)

output = tritonhttpclient.InferRequestedOutput(output_name, binary_data=True)
response = triton_client.infer(model_name, model_version=model_version,
                               inputs=[input0], outputs=[output])

# Keras ResNet50 ends in a softmax by default, so despite the name these are
# class probabilities; argmax selects the same class either way.
logits = np.asarray(response.as_numpy(output_name), dtype=np.float32)

# One predicted label per batch item; argmax over the class axis rather than
# the flattened array, so larger batches are labelled correctly too.
for class_index in np.argmax(logits, axis=-1):
    print(labels[class_index])