In [None]:
# Import libraries
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub

In [None]:
# Parse test images
def decode_test(path, image_size=None):
    """Read one PNG file and turn it into a model-ready float image.

    Args:
        path: scalar string tensor (or Python ``str``) with the image file path.
        image_size: target height and width in pixels. When ``None`` (the
            default, and what ``Dataset.map`` uses), the notebook-level
            ``IMAGE_SIZE`` constant is used. It is looked up at call time, so
            it only has to be defined before the function is first called.

    Returns:
        A float32 tensor of shape ``(image_size, image_size, 3)`` with pixel
        values divided by 255.
    """
    if image_size is None:
        image_size = IMAGE_SIZE
    raw = tf.io.read_file(path)
    # channels=3 forces RGB output (grayscale/RGBA PNGs are converted), matching
    # the (IMAGE_SIZE, IMAGE_SIZE, 3) input shape declared for the model.
    img = tf.image.decode_png(raw, channels=3)
    img = tf.cast(img, tf.float32)
    img = tf.image.resize(img, [image_size, image_size], antialias=True)
    # Scale 0-255 pixel intensities down to the [0, 1] range.
    return img / 255

In [None]:
# Configuration: input resolution, label order, model source and trained weights.
IMAGE_SIZE = 300  # height/width in pixels fed to the model (see input_shape below)
CLASS_NAMES = ["normal", "pneumonia", "COVID-19"]  # index order of the model's output units
MODEL_HANDLE = "https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b3/feature_vector/2"
WEIGHT_PATH = "../model_weight/CL_EffNetv2-B3_weights.h5"

In [None]:
# Build model and load the weight.
# Stack: frozen TF-Hub feature extractor -> dropout -> softmax head with one unit
# per entry in CLASS_NAMES. This layer stack presumably has to match the one used
# when WEIGHT_PATH was saved, or load_weights will fail.
feature_extractor_layer = hub.KerasLayer(
    MODEL_HANDLE,
    input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    trainable=False,
)

model = tf.keras.Sequential()
model.add(feature_extractor_layer)
model.add(tf.keras.layers.Dropout(rate=0.2))
model.add(tf.keras.layers.Dense(len(CLASS_NAMES), activation='softmax'))

model.load_weights(WEIGHT_PATH)
model.summary()

In [None]:
# Load test dataset.
# The CSV must provide `filename`, `label`, and one column per entry in
# CLASS_NAMES. Those columns are presumably one-hot indicators (argmax is taken
# below); confirm against the dataset builder.
TEST_BATCH_SIZE = 32  # bounded batch: inference memory no longer scales with the test-set size

tests_df = pd.read_csv('../dataset/test.csv')
tests_df['path'] = '../dataset/clahe/test/' + tests_df.label + '/' + tests_df.filename

# Dataset.map keeps element order unless deterministic=False is requested, so
# the rows of model.predict(test_ds) line up with the rows of tests_df.
# AUTOTUNE comes from tf.data directly; a bare `AUTOTUNE` name is not defined
# anywhere in this notebook.
test_ds = (
    tf.data.Dataset.from_tensor_slices(tests_df.path.values)
    .map(decode_test, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(TEST_BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

# Ground truth in two forms: class indices (from the CLASS_NAMES columns) for the
# metrics, and raw label strings for the confusion-matrix plot.
test_index = np.argmax(tests_df[CLASS_NAMES].values, axis=1)
test_label = tests_df.label.values

# Make prediction.
# test_pred holds one row of softmax scores per test image, with columns in
# CLASS_NAMES order. The top-scoring column gives the predicted index and name.
test_pred = model.predict(test_ds)
pred_index = test_pred.argmax(axis=1)
pred_label = np.asarray(CLASS_NAMES)[pred_index]

# Create classification report on class indices (true vs. predicted).
# NOTE(review): classification_report / f1_score / accuracy_score are not imported
# in the visible imports cell. They presumably come from sklearn.metrics; confirm
# and add that import there, or this cell raises NameError on a fresh kernel.
# Per scikit-learn's docs, micro-averaged F1 equals accuracy for single-label
# multiclass targets like these, so the last two lines print the same value.
print(classification_report(
    test_index,
    pred_index,
    target_names=CLASS_NAMES,
    digits=4,
    zero_division=0,
))
print('f1_score        :', f1_score(test_index, pred_index, average='micro'))
print('accuracy_score  :', accuracy_score(test_index, pred_index))

# Plot the confusion matrix and ROC curve.
# NOTE(review): `skplt` is not imported in the visible imports cell; presumably
# `import scikitplot as skplt`. Confirm and add it there.
# The confusion matrix compares label strings (raw counts, not normalized).
# The ROC curve uses the true class indices and the per-class scores in test_pred.
cm = skplt.metrics.plot_confusion_matrix(
    test_label,
    pred_label,
    normalize=False,
    figsize=(8, 8),
)
roc = skplt.metrics.plot_roc(
    test_index,
    test_pred,
    figsize=(10, 8),
)