In [5]:
import tensorflow as tf
import wandb
import matplotlib.pyplot as plt
import time
import os

from input import load
from evaluation import evaluate_multiclass

In [7]:
# Initialise wandb only to obtain the run configuration object; the session is
# started with mode="disabled".
# NOTE(review): how `config.evaluate_run` and the dataset options get populated
# is not visible in this notebook -- confirm against the project's config setup.
wandb.init(project="diabetic_retinopathy", entity="stuttgartteam8", mode="disabled")
config = wandb.config

# Build the train / validation / test datasets from the configuration.
ds_train, ds_val, ds_test = load(config)

# Local cache file name: the wandb run path with every "/" replaced by "_".
model_filename = config.evaluate_run.replace("/", "_") + ".h5"

# check if weights file was already downloaded before
if os.path.isfile(model_filename):
    print("Using model from local .h5 file")
else:
    print("Download model from wandb")
    api = wandb.Api()
    run = api.run(config.evaluate_run)
    # The file is downloaded as "model.h5" into the working directory and then
    # renamed so that models from different runs do not overwrite each other.
    run.file("model.h5").download(replace=True)
    # Short pause kept from the original before renaming.
    # NOTE(review): presumably guards against the file not being released yet;
    # confirm whether it is still required.
    time.sleep(1)
    os.rename("model.h5", model_filename)

# compile=False: the model is only used for inference/evaluation here, so it is
# loaded without compiling it.
model = tf.keras.models.load_model(model_filename, compile=False)

# Model.summary() prints the table itself and returns None, so wrapping it in
# print() would emit an extra "None" line after the summary.
model.summary()


Running in multiclass classification mode!
Running in multiclass classification mode!
Running in multiclass classification mode!
Using model from local .h5 file
Model: "multi_class_resnet50"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 tf.cast (TFOpLambda)        (None, 224, 224, 3)       0         
                                                                 
 tf.__operators__.getitem (S  (None, 224, 224, 3)      0         
 licingOpLambda)                                                 
                                                                 
 tf.nn.bias_add (TFOpLambda)  (None, 224, 224, 3)      0         
                                                                 
 resnet50 (Functional)       (None, 7, 7, 2048)        23587712  
                 

In [8]:

# Report the same set of multiclass metrics for the validation split first and
# the test split second. The unpacked result names are kept as notebook-level
# variables, so after this cell they hold the values of the test split.
for heading, dataset in (
    ("--- Validation Scores ---", ds_val),
    ("--- Test Scores ---", ds_test),
):
    print(heading)
    acc, p, r, f1, confm, quadratic_weighted_kappa = evaluate_multiclass(
        config, model, dataset
    )
    print(f"Accuracy: {acc}")
    print(f"Precision: {p}")
    print(f"Recall: {r}")
    print(f"f1-Score: {f1}")
    print(f"Confusion-Matrix: \n{confm}")
    print(f"Quadratic WeightedKappa: {quadratic_weighted_kappa}")

--- Validation Scores ---
Accuracy: 0.575
Precision: 0.4685465838509317
Recall: 0.4352450980392156
f1-Score: 0.45128232071718855
Confusion-Matrix: 
[[22  1  1  1  0]
 [ 5  0  0  0  0]
 [ 7  1  9  2  4]
 [ 0  0  6 12  3]
 [ 0  0  0  3  3]]
Quadratic WeightedKappa: 0.4254330376003379
--- Test Scores ---
Accuracy: 0.5833333333333334
Precision: 0.4740458426772468
Recall: 0.470954535660418
f1-Score: 0.47249513299145335
Confusion-Matrix: 
[[27  2  2  0  1]
 [ 3  0  2  0  0]
 [ 7  1 14  3  6]
 [ 2  0  2  9  4]
 [ 3  0  1  1  6]]
Quadratic WeightedKappa: 0.43512797881729925
