In [None]:
import tensorflow as tf
import os
import logging
import numpy as np
import pandas as pd
import seaborn as sns
from pytz import timezone
from datetime import datetime
from config import *
from utils import *
from import_data import Import_EfficientNetB7_data, Import_TripletNet_test_data
from tensorflow.keras.metrics import Precision, Recall
from tensorflow.keras.optimizers import Adam
from sklearn.metrics import confusion_matrix, classification_report
from model.EfficientNetB7 import EfficientNetB7_model
from model.TripletNet import TripletNet_model
from tensorflow.keras.models import Model, load_model

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# The memory-growth setting has to be the same on every GPU, so it is applied to
# each visible device rather than only the first one (which can fail on multi-GPU hosts).
physical_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in physical_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [None]:
# Select the checkpoint to evaluate from CHECKPOINT_PATH and load it.
# os.listdir() returns entries in arbitrary order, so the names are sorted to make
# the choice deterministic: the lexicographically greatest entry is used.
# NOTE(review): this assumes checkpoint file names sort in training order — confirm.
checkpoint_names = sorted(os.listdir(CHECKPOINT_PATH))
if not checkpoint_names:
    # Fail early with a clear message: the evaluation code needs `model` to exist.
    raise FileNotFoundError(f"No checkpoint found in {CHECKPOINT_PATH}")

checkpoint_path = os.path.join(CHECKPOINT_PATH, checkpoint_names[-1])
print(f"Checkpoint found: {checkpoint_path}")
model = load_model(checkpoint_path)

In [None]:
# Build the data generators required by the configured architecture.
if MODEL_NAME == 'TripletNet':
    triplet_data = Import_TripletNet_test_data(
        IMAGE_SIZE,
        BATCH_SIZE,
        train_data_path=TRAIN_DATA_PATH,
        test_data_path=TEST_DATA_PATH,
    )
    train_triplet_generator, test_triplet_generator = triplet_data.build_generators()

    # Replace the loaded model with a standalone model built from its nested layer
    # named 'model' (presumably the shared embedding network — confirm against
    # model/TripletNet).
    # NOTE: `model` is rebound here, so re-running this cell without reloading the
    # checkpoint may fail because the new model need not contain a 'model' layer.
    embedding_net = model.get_layer('model')
    model = Model(inputs=embedding_net.input, outputs=embedding_net.output)

elif MODEL_NAME == 'EfficientNetB7':
    efficientnet_data = Import_EfficientNetB7_data(
        IMAGE_SIZE,
        BATCH_SIZE,
        test_data_path=TEST_DATA_PATH,
    )
    test_generator = efficientnet_data.build_generators('test')

else:
    # Without this, an unexpected MODEL_NAME only surfaces later as a NameError.
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected 'TripletNet' or 'EfficientNetB7'"
    )

In [None]:
def evaluate_triplet_model(test_triplet_generator, database, model, output_path):
    """Evaluate an embedding model by nearest-embedding classification.

    Each test sample is embedded with ``model``; the label returned by
    ``predict_closest_embedding(embedding, database)`` is taken as the prediction
    and compared with the label supplied by the generator. Accuracy, macro
    precision and macro recall are printed together with a classification report,
    and the following files are written:

    * ``<output_path>/confusion_matrix.png`` - heatmap of the confusion matrix
    * ``<output_path>/result.txt`` - metric summary line followed by the report

    Args:
        test_triplet_generator: Sized iterable yielding ``(batch, labels)`` pairs.
            ``len()`` of it is the number of batches that are consumed; ``labels``
            must support ``.astype(int)``.
        database: Reference data passed unchanged to ``predict_closest_embedding``.
        model: Object whose ``predict(batch, verbose=0)`` returns one embedding per
            sample of ``batch``.
        output_path: Directory the result files are written to. It is not created
            here, so it must already exist.

    Returns:
        None.
    """
    # Imported locally so the function does not rely on these names being leaked
    # into the notebook namespace by a star import.
    from sklearn.metrics import accuracy_score, precision_score, recall_score

    y_true = []
    y_pred = []
    total_batches = len(test_triplet_generator)

    print("Starting to evaluate batches...")
    # Pairing the generator with a bounded range stops after `total_batches`
    # batches even if the generator never ends, and (because the range is checked
    # first) does not pull an extra, unused batch from it.
    batch_numbers = range(1, total_batches + 1)
    for batch_count, (batch, labels) in zip(batch_numbers, test_triplet_generator):
        print(f"Processing batch {batch_count}/{total_batches}")
        embeddings = model.predict(batch, verbose=0)  # one embedding per sample
        labels = labels.astype(int)  # one label per sample
        for emb, true_label in zip(embeddings, labels):
            pred_label, _ = predict_closest_embedding(emb, database)
            y_true.append(true_label)
            y_pred.append(pred_label)

    print("Saving results...")
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='macro')
    recall = recall_score(y_true, y_pred, average='macro')

    conf_mat = confusion_matrix(y_true, y_pred)
    # Classes are named by their row index in the confusion matrix.
    # NOTE(review): this matches the real label ids only if the labels seen in
    # y_true/y_pred are contiguous and start at 0 — confirm for the dataset.
    class_names = [str(i) for i in range(conf_mat.shape[0])]

    df_conf_mat = pd.DataFrame(conf_mat, columns=class_names, index=class_names)
    sns_heatmap = sns.heatmap(data=df_conf_mat, annot=True, fmt='d', linewidths=.5, cmap='BuGn_r')
    sns_heatmap.get_figure().savefig(os.path.join(output_path, "confusion_matrix.png"))

    report = classification_report(y_true, y_pred, digits=5, target_names=class_names)

    summary = f"test_accuracy: {accuracy}, test_precision: {precision}, test_recall: {recall}"
    with open(os.path.join(output_path, "result.txt"), "w") as file:
        file.write(f"{summary}\n")
        file.write(report)

    print(summary)
    print(report)

In [None]:
# Make sure the result directory exists; exist_ok avoids the separate
# exists-check and the race between checking and creating.
os.makedirs(TEST_RESULT_FILE_PATH, exist_ok=True)

if MODEL_NAME == 'TripletNet':
    # Build the reference embeddings from the training data, then classify the
    # test samples against them.
    database = create_embedding_database(train_triplet_generator, model)
    evaluate_triplet_model(test_triplet_generator, database, model, TEST_RESULT_FILE_PATH)

elif MODEL_NAME == 'EfficientNetB7':
    # Named `eval_results` rather than `eval` so the builtin eval() is not shadowed.
    # Indexed below as [test_loss, test_accuracy, test_precision, test_recall].
    eval_results = model.evaluate(test_generator)

    # NOTE(review): assumes predict() visits samples in the same order as
    # test_generator.labels (i.e. the generator does not shuffle) — confirm in import_data.
    y_pred = np.argmax(model.predict(test_generator), axis=-1)
    y_true = test_generator.labels
    np.save(os.path.join(TEST_RESULT_FILE_PATH, "y_pred.npy"), y_pred)
    np.save(os.path.join(TEST_RESULT_FILE_PATH, "y_true.npy"), y_true)

    conf_mat = confusion_matrix(y_true, y_pred)
    # Classes are named by their row index in the confusion matrix.
    # NOTE(review): this matches the real class ids only if every class id from 0
    # upwards occurs in y_true or y_pred — confirm for the dataset.
    target_names = [str(i) for i in range(conf_mat.shape[0])]

    df_conf_mat = pd.DataFrame(conf_mat, columns=target_names, index=target_names)
    sns_heatmap = sns.heatmap(data=df_conf_mat, annot=True, fmt='d', linewidths=.5, cmap='BuGn_r')
    sns_heatmap.get_figure().savefig(os.path.join(TEST_RESULT_FILE_PATH, "confusion_matrix.png"))

    report = classification_report(y_true, y_pred, digits=5, target_names=target_names)

    # Format the summary once so the file and the console always agree.
    summary = (
        f"test_loss: {eval_results[0]}, test_accuracy: {eval_results[1]}, "
        f"test_precision: {eval_results[2]}, test_recall: {eval_results[3]}"
    )
    with open(os.path.join(TEST_RESULT_FILE_PATH, "result.txt"), "w") as file:
        file.write(f"{summary}\n")
        file.write(report)
    print(summary)
    print(report)

else:
    # Previously an unknown MODEL_NAME silently produced no results at all.
    raise ValueError(
        f"Unsupported MODEL_NAME {MODEL_NAME!r}; expected 'TripletNet' or 'EfficientNetB7'"
    )