In [1]:
import tensorflow as tf
from tensorflow.saved_model import load # type: ignore
from tensorflow.keras import layers # type: ignore
from tensorflow.keras.preprocessing import image_dataset_from_directory # type: ignore
import pathlib
from sklearn.metrics import recall_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score


2025-03-20 13:42:12.098892: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-20 13:42:12.410260: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1742478132.542763  239252 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1742478132.579962  239252 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1742478132.885763  239252 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking 

In [None]:
batch_size = 32
img_height = 224
img_width  = 224

#Adjust the path according to your machine
data_dir_test  = pathlib.Path('data_test/')


test_ds = image_dataset_from_directory(
            data_dir_test,
            image_size=(img_height, img_width),
            batch_size=batch_size)

for image_batch, labels_batch in test_ds:
  print(f"👉The shape of each test batch is {image_batch.shape}")
  print(f"  The shape of each target batch is {labels_batch.shape}")
  break

print("👀 The classes that was used for training the model are:")
for name in test_ds.class_names:
    print(f"- {name}")

num_classes = len(test_ds.class_names)


Found 451 files belonging to 7 classes.


W0000 00:00:1742478150.555050  239252 gpu_device.cc:2341] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


👉The shape of each test batch is (32, 224, 224, 3)
  The shape of each target batch is (32,)
👀 The classes that was used for training the model are:
- cataract
- degeneration
- diabets
- glaucoma
- hypertension
- myopia
- normal


In [3]:
from tensorflow.keras import layers

#Normalization of the image tensors
normalization_layer =   layers.Rescaling(1./255)
normalized_test_ds  =  test_ds.map(lambda x, y: (normalization_layer(x), y))

In [None]:

# Load the model
model = load("saved_model/resnet_50")


In [5]:
print(list(model.signatures.keys()))  # E.g., ['serving_default']

['serve', 'serving_default']


In [6]:
infer = model.signatures["serving_default"]
print(infer.structured_input_signature)  # Input details
print(infer.structured_outputs)          # Output details

((), {'keras_tensor': TensorSpec(shape=(None, 224, 224, 3), dtype=tf.float32, name='keras_tensor')})
{'output_0': TensorSpec(shape=(None, 7), dtype=tf.float32, name='output_0')}


In [7]:
y_true = []
y_pred = []
for images, labels in normalized_test_ds:
    predictions = infer(images)  # Use the serving function
    predicted_classes = tf.argmax(predictions['output_0'], axis=1)  # Adjust for your output key
    y_true.extend(labels.numpy())
    y_pred.extend(predicted_classes.numpy())


2025-03-20 13:44:07.703953: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [None]:
for i in range(7):
    class_id = i
    class_recall = recall_score(
        [1 if label == class_id else 0 for label in y_true],
        [1 if pred == class_id else 0 for pred in y_pred], average = 'macro',
    )
    class_accuracy = accuracy_score(
        [1 if label == class_id else 0 for label in y_true],
        [1 if pred == class_id else 0 for pred in y_pred])

    print(f"Recall for {test_ds.class_names[i]}: {class_recall:.2} \t Accuracy for {test_ds.class_names[i]}: {class_accuracy:.2}" )

Recall for cataract: 0.49 	 Accuracy for cataract: 0.94
Recall for degeneration: 0.5 	 Accuracy for degeneration: 0.95
Recall for diabets: 0.5 	 Accuracy for diabets: 0.72
Recall for glaucoma: 0.5 	 Accuracy for glaucoma: 0.94
Recall for hypertension: 0.6 	 Accuracy for hypertension: 0.98
Recall for myopia: 0.61 	 Accuracy for myopia: 0.46
Recall for normal: 0.5 	 Accuracy for normal: 0.5
