-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict_resnet_onsets.py
62 lines (49 loc) · 2.24 KB
/
predict_resnet_onsets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#coding=utf-8
import tensorflow as tf
import numpy as np
import pdb
import cv2
import os
import glob
import slim.nets.resnet_v2 as resnet
from create_tf_record import *
import tensorflow.contrib.slim as slim
def predict(models_path,image_dir,labels_filename,labels_nums, data_format):
[batch_size, resize_height, resize_width, depths] = data_format
labels = np.loadtxt(labels_filename, str, delimiter='\t')
input_images = tf.placeholder(dtype=tf.float32, shape=[None, resize_height, resize_width, depths], name='input')
with slim.arg_scope(resnet.resnet_arg_scope()):
out, end_points = resnet.resnet_v2_50(inputs=input_images, num_classes=labels_nums, is_training=False)
# 将输出结果进行softmax分布,再求最大概率所属类别
score = tf.nn.softmax(out,name='pre')
class_id = tf.argmax(score, 1)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver()
saver.restore(sess, models_path)
images_list=glob.glob(os.path.join(image_dir,'*.jpg'))
score_total = 0
for image_path in images_list:
im=read_image(image_path,resize_height,resize_width,normalization=True)
im=im[np.newaxis,:]
#pred = sess.run(f_cls, feed_dict={x:im, keep_prob:1.0})
pre_score,pre_label = sess.run([score,class_id], feed_dict={input_images:im})
max_score=pre_score[0,pre_label]
#print("{} is: pre labels:{},name:{} score: {}".format(image_path, pre_label, labels[pre_label], max_score))
if image_path.split(".jpg")[0].split("-")[2] == labels[pre_label]:
score_total += 1
else:
print("{} is predicted as label::{} ".format(image_path,labels[pre_label]))
print("valuation accuracy is {}".format(score_total/len(images_list)))
sess.close()
if __name__ == '__main__':
class_nums=4
image_dir='./onsets/test/B/'
labels_filename='./onsets/label.txt'
models_path='./models/onsets/resnet/model.ckpt-10000'
batch_size = 1 #
resize_height = 224 # 指定存储图片高度
resize_width = 224 # 指定存储图片宽度
depths=3
data_format=[batch_size,resize_height,resize_width,depths]
predict(models_path,image_dir, labels_filename, class_nums, data_format)