In [1]:
import tensorflow as tf
from tensorflow.keras.models import load_model
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from prettytable import PrettyTable

from PIL import Image
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
import cv2

import shutil
import glob
import random



In [2]:
# All data lives in ../data relative to the notebook's working directory.
DATA_DIR = os.path.join(os.getcwd(), '..', 'data')
TRAIN_DATA_PATH = os.path.join(DATA_DIR, 'train')
TEST_DATA_PATH = os.path.join(DATA_DIR, 'test')
MODEL_PATH = os.path.join(DATA_DIR, 'model.h5')

IMAGE_SIZE = 32  # edge length in pixels of the square single-channel model input


In [3]:
# Lädt ein Bild
def load_image(path,image_size=IMAGE_SIZE):
    img =cv2.imread(path)
    img=cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    tmp = img.reshape([IMAGE_SIZE, IMAGE_SIZE,1])


    return np.array(tmp)/255

In [4]:
# Test auf Dataset
def test_dataset(model, labels_decoded, path, image_size=IMAGE_SIZE):
    correct_matches = 0
    result_table = PrettyTable()
    result_table.field_names = ["Datei ", "Ist", "Dekodiert", "Match?"]

    for filename in sorted(os.listdir(path)):

        if(filename.startswith('.') == False):

            current_wavelength = filename[0:7]
            print (current_wavelength)
            image_path = os.path.join(path,filename)
                
            test_image = load_image(image_path)
            predictions = model.predict(test_image.reshape((1,IMAGE_SIZE,IMAGE_SIZE,1)))
                
            index_max_predictions = np.argmax(predictions)
            print('index_max_predictions:',index_max_predictions, current_wavelength, labels_decoded[index_max_predictions])
            decode_wavelength = labels_decoded[index_max_predictions]

            # Passt oder nicht?
            if( str.upper(current_wavelength) == str.upper(decode_wavelength)):
                result_table.add_row([image_path, current_wavelength, decode_wavelength, "✅" ])
                correct_matches = correct_matches + 1 
            else:
                result_table.add_row([image_path, current_wavelength, decode_wavelength, "❌" ])


    print(result_table)

In [5]:
# copy some training data to directory test, if no test data is provided
#
shutil.rmtree(TEST_DATA_PATH)
os.mkdir(TEST_DATA_PATH)
for directory in sorted(os.listdir(TRAIN_DATA_PATH)):
    if(directory.startswith('.') == False):
        p = os.path.join(TRAIN_DATA_PATH,directory)
        for filename in sorted(os.listdir(p)):
            if random.random() > 0.99:
                
                shutil.copyfile(os.path.join(p,filename),os.path.join(TEST_DATA_PATH,filename))

        

        



In [6]:
# Class labels are the visible sub-directory names of the training folder, sorted.
# NOTE(review): label i must correspond to model output i — presumably training
# enumerated the folders the same way; confirm against the training notebook.
labels_decoded = [
    entry
    for entry in sorted(os.listdir(TRAIN_DATA_PATH))
    if not entry.startswith('.')
]

model = load_model(MODEL_PATH)
test_dataset(model, labels_decoded, TEST_DATA_PATH)

2023-07-15 16:22:26.411160: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1 Pro
2023-07-15 16:22:26.411185: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2023-07-15 16:22:26.411189: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2023-07-15 16:22:26.411221: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-07-15 16:22:26.411236: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


5852.48
index_max_predictions: 0 5852.48 5852.48
6029.99
index_max_predictions: 4 6029.99 6029.99
6128.44
index_max_predictions: 7 6128.44 6128.44
6128.44
index_max_predictions: 7 6128.44 6128.44
6128.44
index_max_predictions: 7 6128.44 6128.44
6163.59


2023-07-15 16:22:26.718927: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.


index_max_predictions: 9 6163.59 6163.59
6163.59
index_max_predictions: 9 6163.59 6163.59
6163.59
index_max_predictions: 9 6163.59 6163.59
6217.28
index_max_predictions: 10 6217.28 6217.28
6217.28
index_max_predictions: 10 6217.28 6217.28
6266.49
index_max_predictions: 11 6266.49 6266.49
6304.78
index_max_predictions: 12 6304.78 6304.78
6382.99
index_max_predictions: 14 6382.99 6382.99
6506.52
index_max_predictions: 16 6506.52 6506.52
6532.88
index_max_predictions: 17 6532.88 6532.88
6532.88
index_max_predictions: 17 6532.88 6532.88
6598.95
index_max_predictions: 18 6598.95 6598.95
6598.95
index_max_predictions: 18 6598.95 6598.95
6929.46
index_max_predictions: 21 6929.46 6929.46
7173.93
index_max_predictions: 23 7173.93 7173.93
7245.16
index_max_predictions: 24 7245.16 7245.16
7245.16
index_max_predictions: 24 7245.16 7245.16
7438.89
index_max_predictions: 25 7438.89 7438.89
7438.89
index_max_predictions: 25 7438.89 7438.89
+------------------------------------------------------------