Skip to content

Commit

Permalink
update cnnlib
Browse files Browse the repository at this point in the history
  • Loading branch information
nickliqian committed Jun 1, 2019
1 parent c41af8c commit 1c84896
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 3 deletions.
103 changes: 103 additions & 0 deletions cnnlib/network.py
@@ -0,0 +1,103 @@
import tensorflow as tf
import numpy as np
import os
from PIL import Image
import random


class CNN(object):
def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir):
# 初始值
self.image_height = image_height
self.image_width = image_width
self.max_captcha = max_captcha
self.char_set = char_set
self.char_set_len = len(char_set)
self.model_save_dir = model_save_dir # 模型路径
with tf.name_scope('parameters'):
self.w_alpha = 0.01
self.b_alpha = 0.1
# tf初始化占位符
with tf.name_scope('data'):
self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量
self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签
self.keep_prob = tf.placeholder(tf.float32) # dropout值

@staticmethod
def convert2gray(img):
"""
图片转为灰度图,如果是3通道图则计算,单通道图则直接返回
:param img:
:return:
"""
if len(img.shape) > 2:
r, g, b = img[:, :, 0], img[:, :, 1], img[:, :, 2]
gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
return gray
else:
return img

def text2vec(self, text):
"""
转标签为oneHot编码
:param text: str
:return: numpy.array
"""
text_len = len(text)
if text_len > self.max_captcha:
raise ValueError('验证码最长{}个字符'.format(self.max_captcha))

vector = np.zeros(self.max_captcha * self.char_set_len)

for i, ch in enumerate(text):
idx = i * self.char_set_len + self.char_set.index(ch)
vector[idx] = 1
return vector

def model(self):
x = tf.reshape(self.X, shape=[-1, self.image_height, self.image_width, 1])
print(">>> input x: {}".format(x))

# 卷积层1
wc1 = tf.get_variable(name='wc1', shape=[3, 3, 1, 32], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer())
bc1 = tf.Variable(self.b_alpha * tf.random_normal([32]))
conv1 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(x, wc1, strides=[1, 1, 1, 1], padding='SAME'), bc1))
conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
conv1 = tf.nn.dropout(conv1, self.keep_prob)

# 卷积层2
wc2 = tf.get_variable(name='wc2', shape=[3, 3, 32, 64], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer())
bc2 = tf.Variable(self.b_alpha * tf.random_normal([64]))
conv2 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv1, wc2, strides=[1, 1, 1, 1], padding='SAME'), bc2))
conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
conv2 = tf.nn.dropout(conv2, self.keep_prob)

# 卷积层3
wc3 = tf.get_variable(name='wc3', shape=[3, 3, 64, 128], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer())
bc3 = tf.Variable(self.b_alpha * tf.random_normal([128]))
conv3 = tf.nn.relu(tf.nn.bias_add(tf.nn.conv2d(conv2, wc3, strides=[1, 1, 1, 1], padding='SAME'), bc3))
conv3 = tf.nn.max_pool(conv3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
conv3 = tf.nn.dropout(conv3, self.keep_prob)
print(">>> convolution 3: ", conv3.shape)
next_shape = conv3.shape[1] * conv3.shape[2] * conv3.shape[3]

# 全连接层1
wd1 = tf.get_variable(name='wd1', shape=[next_shape, 1024], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer())
bd1 = tf.Variable(self.b_alpha * tf.random_normal([1024]))
dense = tf.reshape(conv3, [-1, wd1.get_shape().as_list()[0]])
dense = tf.nn.relu(tf.add(tf.matmul(dense, wd1), bd1))
dense = tf.nn.dropout(dense, self.keep_prob)

# 全连接层2
wout = tf.get_variable('name', shape=[1024, self.max_captcha * self.char_set_len], dtype=tf.float32,
initializer=tf.contrib.layers.xavier_initializer())
bout = tf.Variable(self.b_alpha * tf.random_normal([self.max_captcha * self.char_set_len]))

with tf.name_scope('y_prediction'):
y_predict = tf.add(tf.matmul(dense, wout), bout)

return y_predict
76 changes: 76 additions & 0 deletions cnnlib/recognition_object.py
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""
识别图像的类,为了快速进行多次识别可以调用此类下面的方法:
R = Recognizer(image_height, image_width, max_captcha)
for i in range(10):
r_img = Image.open(str(i) + ".jpg")
t = R.rec_image(r_img)
简单的图片每张基本上可以达到毫秒级的识别速度
"""
import tensorflow as tf
import numpy as np
from PIL import Image
from conf.sample import sample_conf
from cnnlib.network import CNN


class Recognizer(CNN):
def __init__(self, image_height, image_width, max_captcha, char_set, model_save_dir):
# 初始化变量
super(Recognizer, self).__init__(image_height, image_width, max_captcha, char_set, model_save_dir)

# 新建图和会话
self.g = tf.Graph()
self.sess = tf.Session(graph=self.g)
# 使用指定的图和会话
with self.g.as_default():
# 迭代循环前,写出所有用到的张量的计算表达式,如果写在循环中,会发生内存泄漏,拖慢识别的速度
# tf初始化占位符
self.X = tf.placeholder(tf.float32, [None, self.image_height * self.image_width]) # 特征向量
self.Y = tf.placeholder(tf.float32, [None, self.max_captcha * self.char_set_len]) # 标签
self.keep_prob = tf.placeholder(tf.float32) # dropout值
# 加载网络和模型参数
self.y_predict = self.model()
self.predict = tf.argmax(tf.reshape(self.y_predict, [-1, self.max_captcha, self.char_set_len]), 2)
saver = tf.train.Saver()
with self.sess.as_default() as sess:
saver.restore(sess, self.model_save_dir)

# def __del__(self):
# self.sess.close()
# print("session close")

def rec_image(self, img):
# 读取图片
img_array = np.array(img)
test_image = self.convert2gray(img_array)
test_image = test_image.flatten() / 255
# 使用指定的图和会话
with self.g.as_default():
with self.sess.as_default() as sess:
text_list = sess.run(self.predict, feed_dict={self.X: [test_image], self.keep_prob: 1.})

# 获取结果
predict_text = text_list[0].tolist()
p_text = ""
for p in predict_text:
p_text += str(self.char_set[p])

# 返回识别结果
return p_text


def main():
image_height = sample_conf["image_height"]
image_width = sample_conf["image_width"]
max_captcha = sample_conf["max_captcha"]
char_set = sample_conf["char_set"]
model_save_dir = sample_conf["model_save_dir"]
R = Recognizer(image_height, image_width, max_captcha, char_set, model_save_dir)
r_img = Image.open("./sample/test/2b3n_6915e26c67a52bc0e4e13d216eb62b37.jpg")
t = R.rec_image(r_img)
print(t)


if __name__ == '__main__':
main()
2 changes: 1 addition & 1 deletion test_batch.py
Expand Up @@ -7,7 +7,7 @@
from PIL import Image
import random
import os
from lib.network import CNN
from cnnlib.network import CNN


class TestError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion train_model.py
Expand Up @@ -8,7 +8,7 @@
from PIL import Image
import random
import os
from lib.network import CNN
from cnnlib.network import CNN


class TrainError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion webserver_recognize_api.py
Expand Up @@ -10,7 +10,7 @@
import json
from io import BytesIO
import os
from lib.recognition_object import Recognizer
from cnnlib.recognition_object import Recognizer

import time
from flask import Flask, request, jsonify, Response
Expand Down

0 comments on commit 1c84896

Please sign in to comment.