In [1]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, CategoryEncoding, Input, Rescaling
from tensorflow.keras.metrics import Precision, Recall, BinaryAccuracy
from tensorflow.keras.models import load_model
import keras
from keras import backend as K
import os
import cv2
import imghdr
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [None]:
# provide custom loss function

weights = tf.constant([42, 2.5])

def weighted_cross_entropy(y_true, y_pred):
    # compute softmax cross entropy
    ce = tf.nn.softmax_cross_entropy_with_logits(labels=y_true, logits=y_pred)
    
    # get class from one-hot encoded label
    y_true_class = tf.argmax(y_true, axis=1)
    
    # compute weights based on true class
    weights_ce = tf.gather(weights, y_true_class)
    
    # apply weights
    weighted_ce = ce * weights_ce
    
    return tf.reduce_mean(weighted_ce)

In [2]:
# load base model w/ standard loss
model_path = os.path.join('..', 'models', 'dense3_softmax', 'particledrag_dataset=230511_epochs=15_BinaryCrossentropy.h5')

model = load_model(model_path)

Metal device set to: Apple M1 Max

systemMemory: 32.00 GB
maxCacheSize: 10.67 GB



In [None]:
# load model
model_path = os.path.join('..', 'models', 'dense3_softmax', 'particledrag_dataset=230511_epochs=15_WeightedCategoricalCrossentropy.h5')

model = load_model(model_path, custom_objects={'weighted_cross_entropy': weighted_cross_entropy})

In [3]:
# load test set
test_dir = os.path.join('..', '230503_Hot Run', 'data')

test_data = tf.keras.utils.image_dataset_from_directory(test_dir, image_size=(480, 640), label_mode='categorical')

Found 27805 files belonging to 2 classes.


In [5]:
# compare to true labels
y_true = np.concatenate([y for x, y in test_data], axis=0)
y_pred = model.predict(test_data)

y_pred_classes = np.argmax(y_pred, axis=1)
y_true_classes = np.argmax(y_true, axis=1)

accuracy = accuracy_score(y_true_classes, y_pred_classes)
precision = precision_score(y_true_classes, y_pred_classes, average='weighted')
f1 = f1_score(y_true_classes, y_pred_classes, average='weighted')

# Print the metrics
print('Accuracy:', accuracy)
print('Precision:', precision)
print('F1 score:', f1)

Accuracy: 0.8470059341844992
Precision: 0.759770240765735
Recall: 0.8470059341844992
F1 score: 0.7941972649608086
