-
Notifications
You must be signed in to change notification settings - Fork 0
/
test-quality-net.py
70 lines (59 loc) · 2.58 KB
/
test-quality-net.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import argparse
from tqdm import tqdm
import tensorflow as tf
from tensorflow import keras
from quality_net_utilities import *
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Trains a feature-aware normalization model.')
parser.add_argument('--dataset_path',dest='dataset_path',
action='store',type=str,default=None,
help="Path to hdf5 dataset.")
parser.add_argument('--checkpoint_path',dest='checkpoint_path',
action='store',type=str,default=None,
help="Path to hdf5 dataset.")
parser.add_argument('--input_height',dest = 'input_height',
action = 'store',type = int,default = 256,
help = 'The file extension for all images.')
parser.add_argument('--input_width',dest = 'input_width',
action = 'store',type = int,default = 256,
help = 'The file extension for all images.')
parser.add_argument('--batch_size',dest = 'batch_size',
action = 'store',type = int,default = 4,
help = 'Size of mini batch.')
args = parser.parse_args()
print("Setting up network...")
quality_net = keras.models.load_model(args.checkpoint_path)
print("Setting up data generator...")
data_generator = DataGenerator(args.dataset_path,None)
def load_generator():
for image,label in data_generator.generate():
yield image,label
generator = load_generator
output_types = (tf.float32,tf.float32)
output_shapes = (
tf.TensorShape((args.input_height,args.input_width,3)),
tf.TensorShape((1)))
tf_dataset = tf.data.Dataset.from_generator(
generator,output_types=output_types,output_shapes=output_shapes)
tf_dataset = tf_dataset.batch(args.batch_size)
tf_dataset = tf_dataset.prefetch(args.batch_size*5)
print("Setting up testing...")
auc = tf.keras.metrics.AUC()
acc = Accuracy()
rec = Recall()
prec = Precision()
print("Testing...")
for image,c in tqdm(tf_dataset):
prediction = quality_net(image)
auc.update_state(c,prediction)
acc.update_state(c,prediction)
rec.update_state(c,prediction)
prec.update_state(c,prediction)
print('{},{}'.format(
'AUC',float(auc.result().numpy())))
print('{},{}'.format(
'Accuracy',float(acc.result().numpy())))
print('{},{}'.format(
'Recall',float(rec.result().numpy())))
print('{},{}'.format(
'Precision',float(prec.result().numpy())))