Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c41af8c
commit 1c84896
Showing
5 changed files
with
182 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import tensorflow as tf | ||
import numpy as np | ||
import os | ||
from PIL import Image | ||
import random | ||
|
||
|
||
class CNN(object): | ||
def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir): | ||
# 初始值 | ||
self.image_height = image_height | ||
self.image_width = image_width | ||
self.max_captcha = max_captcha | ||
self.char_set = char_set | ||
self.char_set_len = len(char_set) | ||
self.model_save_dir = model_save_dir # 模型路径 | ||
with tf.name_scope('parameters'): | ||
self.w_alpha = 0.01 | ||
self.b_alpha = 0.1 | ||
# tf初始化占位符 | ||
with tf.name_scope('data'): | ||
self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量 | ||
self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签 | ||
self.keep_prob = tf.placeholder(tf.float32) # dropout值 | ||
|
||
@staticmethod | ||
def convert2gray(img): | ||
""" | ||
图片转为灰度图,如果是3通道图则计算,单通道图则直接返回 | ||
:param img: | ||
:return: | ||
""" | ||
if len(img.shape) > 2: | ||
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2] | ||
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b | ||
return gray | ||
else: | ||
return img | ||
|
||
def text2vec(self, text): | ||
""" | ||
转标签为oneHot编码 | ||
:param text: str | ||
:return: numpy.array | ||
""" | ||
text_len = len(text) | ||
if text_len > self.max_captcha: | ||
raise ValueError('验证码最长{}个字符'.format(self.max_captcha)) | ||
|
||
vector = np.zeros(self.max_captcha * self.char_set_len) | ||
|
||
for i, ch in enumerate(text): | ||
idx = i * self.char_set_len + self.char_set.index(ch) | ||
vector[idx] = 1 | ||
return vector | ||
|
||
def model(self): | ||
x = tf.reshape(self.X, shape=[-1, self.image_height, self.image_width, 1]) | ||
print(">>> input x: {}".format(x)) | ||
|
||
# 卷积层1 | ||
wc1 = tf.get_variable(name='wc1', shape=[3, 3, 1, 32], dtype=tf.float32, | ||
initializer=tf.contrib.layers.xavier_initializer()) | ||
bc1 = tf.Variable(self.b_alpha * tf.random_normal([32])) | ||
conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, wc1, strides=[1, 1, 1, 1], padding='SAME'), bc1)) | ||
conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') | ||
conv1 = tf.nn.dropout(conv1, self.keep_prob) | ||
|
||
# 卷积层2 | ||
wc2 = tf.get_variable(name='wc2', shape=[3, 3, 32, 64], dtype=tf.float32, | ||
initializer=tf.contrib.layers.xavier_initializer()) | ||
bc2 = tf.Variable(self.b_alpha * tf.random_normal([64])) | ||
conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, wc2, strides=[1, 1, 1, 1], padding='SAME'), bc2)) | ||
conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') | ||
conv2 = tf.nn.dropout(conv2, self.keep_prob) | ||
|
||
# 卷积层3 | ||
wc3 = tf.get_variable(name='wc3', shape=[3, 3, 64, 128], dtype=tf.float32, | ||
initializer=tf.contrib.layers.xavier_initializer()) | ||
bc3 = tf.Variable(self.b_alpha * tf.random_normal([128])) | ||
conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, wc3, strides=[1, 1, 1, 1], padding='SAME'), bc3)) | ||
conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME') | ||
conv3 = tf.nn.dropout(conv3, self.keep_prob) | ||
print(">>> convolution 3: ", conv3.shape) | ||
next_shape = conv3.shape[1] * conv3.shape[2] * conv3.shape[3] | ||
|
||
# 全连接层1 | ||
wd1 = tf.get_variable(name='wd1', shape=[next_shape, 1024], dtype=tf.float32, | ||
initializer=tf.contrib.layers.xavier_initializer()) | ||
bd1 = tf.Variable(self.b_alpha * tf.random_normal([1024])) | ||
dense = tf.reshape(conv3, [-1, wd1.get_shape().as_list()[0]]) | ||
dense = tf.nn.relu(tf.add(tf.matmul(dense, wd1), bd1)) | ||
dense = tf.nn.dropout(dense, self.keep_prob) | ||
|
||
# 全连接层2 | ||
wout = tf.get_variable('name', shape=[1024, self.max_captcha * self.char_set_len], dtype=tf.float32, | ||
initializer=tf.contrib.layers.xavier_initializer()) | ||
bout = tf.Variable(self.b_alpha * tf.random_normal([self.max_captcha * self.char_set_len])) | ||
|
||
with tf.name_scope('y_prediction'): | ||
y_predict = tf.add(tf.matmul(dense, wout), bout) | ||
|
||
return y_predict |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
识别图像的类,为了快速进行多次识别可以调用此类下面的方法: | ||
R = Recognizer(image_height, image_width, max_captcha) | ||
for i in range(10): | ||
r_img = Image.open(str(i) + ".jpg") | ||
t = R.rec_image(r_img) | ||
简单的图片每张基本上可以达到毫秒级的识别速度 | ||
""" | ||
import tensorflow as tf | ||
import numpy as np | ||
from PIL import Image | ||
from conf.sample import sample_conf | ||
from cnnlib.network import CNN | ||
|
||
|
||
class Recognizer(CNN): | ||
def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir): | ||
# 初始化变量 | ||
super(Recognizer, self).__init__(image_height, image_width, max_captcha, char_set, model_save_dir) | ||
|
||
# 新建图和会话 | ||
self.g = tf.Graph() | ||
self.sess = tf.Session(graph=self.g) | ||
# 使用指定的图和会话 | ||
with self.g.as_default(): | ||
# 迭代循环前,写出所有用到的张量的计算表达式,如果写在循环中,会发生内存泄漏,拖慢识别的速度 | ||
# tf初始化占位符 | ||
self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量 | ||
self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签 | ||
self.keep_prob = tf.placeholder(tf.float32) # dropout值 | ||
# 加载网络和模型参数 | ||
self.y_predict = self.model() | ||
self.predict = tf.argmax(tf.reshape(self.y_predict, [-1, self.max_captcha, self.char_set_len]), 2) | ||
saver = tf.train.Saver() | ||
with self.sess.as_default() as sess: | ||
saver.restore(sess, self.model_save_dir) | ||
|
||
# def __del__(self): | ||
# self.sess.close() | ||
# print("session close") | ||
|
||
def rec_image(self, img): | ||
# 读取图片 | ||
img_array = np.array(img) | ||
test_image = self.convert2gray(img_array) | ||
test_image = test_image.flatten() / 255 | ||
# 使用指定的图和会话 | ||
with self.g.as_default(): | ||
with self.sess.as_default() as sess: | ||
text_list = sess.run(self.predict, feed_dict={self.X: [test_image], self.keep_prob: 1.}) | ||
|
||
# 获取结果 | ||
predict_text = text_list[0].tolist() | ||
p_text = "" | ||
for p in predict_text: | ||
p_text += str(self.char_set[p]) | ||
|
||
# 返回识别结果 | ||
return p_text | ||
|
||
|
||
def main(): | ||
image_height = sample_conf["image_height"] | ||
image_width = sample_conf["image_width"] | ||
max_captcha = sample_conf["max_captcha"] | ||
char_set = sample_conf["char_set"] | ||
model_save_dir = sample_conf["model_save_dir"] | ||
R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir) | ||
r_img = Image.open("./sample/test/2b3n_6915e26c67a52bc0e4e13d216eb62b37.jpg") | ||
t = R.rec_image(r_img) | ||
print(t) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters