In [1]:
import os
import tensorflow as tf
from PIL import Image
from nets import nets_factory
import numpy as np

# Uncomment to restrict TensorFlow to specific GPUs.
# os.environ['CUDA_VISIBLE_DEVICES']='0,1'

# Number of distinct characters (classes) per captcha position.
CHAR_SET_LEN = 10
# Raw captcha image height. NOTE(review): the input pipeline works on 224x224
# images, and this constant is not referenced by the cells shown here.
IMAGE_HEIGHT = 60
# Raw captcha image width (likewise not referenced by the 224x224 pipeline).
IMAGE_WIDTH = 160
# Side length of the square grayscale images stored in the TFRecord file
# and fed to the network.
IMAGE_SIZE = 224
# Number of examples per training batch.
BATCH_SIZE = 25
# Path of the TFRecord training file.
TFRECORD_FILE = "/root/tfrecord/train.tfrecord"

# Placeholders: a batch of IMAGE_SIZE x IMAGE_SIZE grayscale images, and one
# integer class index per captcha character position. The labels are int32
# because they are class indices, not real-valued targets.
x = tf.placeholder(tf.float32,[None,IMAGE_SIZE,IMAGE_SIZE])
y0 = tf.placeholder(tf.int32,[None])
y1 = tf.placeholder(tf.int32,[None])
y2 = tf.placeholder(tf.int32,[None])
y3 = tf.placeholder(tf.int32,[None])

# Learning rate, kept in a variable so it can be decayed during training.
lr = tf.Variable(0.003,dtype=tf.float32)

#从tfrecord读出数据
def read_and_decode(filename, image_size=224):
    """Build ops that read and decode one example from a TFRecord file.

    Each record must contain a raw uint8 image under 'image' and four int64
    character labels under 'label0' .. 'label3'.

    Args:
        filename: path to the TFRecord file. The filename queue is created
            without an epoch limit, so the file is cycled indefinitely.
        image_size: side length of the square grayscale image stored in each
            record. The decoded bytes must contain exactly
            image_size * image_size values (default 224).

    Returns:
        A tuple (image, label0, label1, label2, label3):
        - image is a float32 tensor of shape [image_size, image_size], with
          pixel values rescaled from [0, 255] to [-1, 1];
        - each label is an int32 scalar tensor.
    """
    # Queue of input filenames consumed by the reader.
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    # reader.read returns (key, serialized record); only the record is needed.
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example, features={
        'image': tf.FixedLenFeature([], tf.string),
        'label0': tf.FixedLenFeature([], tf.int64),
        'label1': tf.FixedLenFeature([], tf.int64),
        'label2': tf.FixedLenFeature([], tf.int64),
        'label3': tf.FixedLenFeature([], tf.int64),
    })

    # Decode the raw image bytes.
    image = tf.decode_raw(features['image'], tf.uint8)
    # tf.train.shuffle_batch needs a fully defined static shape. If a record's
    # byte count does not match image_size * image_size, this reshape fails
    # inside the queue-runner thread. That failure can close the batching queue,
    # which the training loop then sees as an OutOfRangeError.
    image = tf.reshape(image, [image_size, image_size])
    # Preprocess: scale [0, 255] -> [0, 1] -> [-0.5, 0.5] -> [-1, 1].
    image = tf.cast(image, tf.float32) / 255.0
    image = tf.subtract(image, 0.5)
    image = tf.multiply(image, 2.0)

    # Labels: one class index per captcha character position.
    label0 = tf.cast(features['label0'], tf.int32)
    label1 = tf.cast(features['label1'], tf.int32)
    label2 = tf.cast(features['label2'], tf.int32)
    label3 = tf.cast(features['label3'], tf.int32)

    return image, label0, label1, label2, label3

    
    
    
    
    

    
    
    
    

In [2]:
# Decode single examples (image plus four character labels) from the TFRecord file.
image,label0,label1,label2,label3 = read_and_decode(TFRECORD_FILE)
# shuffle_batch randomly shuffles examples while assembling batches.
image_batch,label_batch0,label_batch1,label_batch2,label_batch3 = tf.train.shuffle_batch(
    [image,label0,label1,label2,label3],
    batch_size=BATCH_SIZE,capacity=50000,min_after_dequeue=10000,num_threads=1)

# Network definition. NOTE(review): the call below unpacks four logits tensors plus
# end_points, so nets_factory presumably maps 'alexnet_v2' to a modified
# multi-output AlexNet rather than the stock slim model -- confirm.
train_network_fn = nets_factory.get_network_fn('alexnet_v2',num_classes=CHAR_SET_LEN,weight_decay=0.0005,is_training=True)

with tf.Session() as sess:

    # The network takes a 4-D input [batch_size, height, width, channels];
    # add the single grayscale channel.
    X = tf.reshape(x,[BATCH_SIZE,224,224,1])
    # Run the inputs through the network: one logits tensor per character position.
    logits0,logits1,logits2,logits3,end_points = train_network_fn(X)

    # Convert the integer labels to one-hot vectors.
    one_hot_label0 = tf.one_hot(indices=tf.cast(y0,tf.int32),depth=CHAR_SET_LEN)
    one_hot_label1 = tf.one_hot(indices=tf.cast(y1,tf.int32),depth=CHAR_SET_LEN)
    one_hot_label2 = tf.one_hot(indices=tf.cast(y2,tf.int32),depth=CHAR_SET_LEN)
    one_hot_label3 = tf.one_hot(indices=tf.cast(y3,tf.int32),depth=CHAR_SET_LEN)

    # Per-position softmax cross-entropy loss.
    loss0 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits0,labels=one_hot_label0))
    loss1 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits1,labels=one_hot_label1))
    loss2 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits2,labels=one_hot_label2))
    loss3 = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits3,labels=one_hot_label3))

    # Average the four losses and minimise the result with Adam.
    total_loss = (loss0+loss1+loss2+loss3)/4.0
    optimizer=tf.train.AdamOptimizer(learning_rate=lr).minimize(total_loss)

    # Per-position accuracy.
    correct_prediction0 = tf.equal(tf.argmax(one_hot_label0,1),tf.argmax(logits0,1))
    accuracy0 = tf.reduce_mean(tf.cast(correct_prediction0,tf.float32))

    correct_prediction1 = tf.equal(tf.argmax(one_hot_label1,1),tf.argmax(logits1,1))
    accuracy1 = tf.reduce_mean(tf.cast(correct_prediction1,tf.float32))

    correct_prediction2 = tf.equal(tf.argmax(one_hot_label2,1),tf.argmax(logits2,1))
    accuracy2 = tf.reduce_mean(tf.cast(correct_prediction2,tf.float32))

    correct_prediction3 = tf.equal(tf.argmax(one_hot_label3,1),tf.argmax(logits3,1))
    accuracy3 = tf.reduce_mean(tf.cast(correct_prediction3,tf.float32))

    # Build the learning-rate decay op once. Calling tf.assign inside the loop
    # would add a new node to the graph every time it runs.
    decay_lr = tf.assign(lr,lr/3)

    saver = tf.train.Saver()

    sess.run(tf.global_variables_initializer())

    # Coordinator that manages the queue-runner threads.
    coord = tf.train.Coordinator()
    # Start the QueueRunners that fill the filename and shuffle-batch queues.
    threads = tf.train.start_queue_runners(sess=sess,coord=coord)

    try:
        for i in range(6001):
            # Fetch one batch of images and labels.
            b_image,b_label0,b_label1,b_label2,b_label3 = sess.run([image_batch,label_batch0,label_batch1,label_batch2,label_batch3])
            feed = {x:b_image,y0:b_label0,y1:b_label1,y2:b_label2,y3:b_label3}
            # One optimisation step.
            sess.run(optimizer,feed_dict=feed)
            # Every 20 iterations, report loss and accuracy on the current batch.
            if i%20 == 0:
                # Divide the learning rate by 3 every 2000 iterations. The i > 0
                # guard keeps the configured initial rate instead of cutting it
                # right after the very first step.
                if i > 0 and i%2000 == 0:
                    sess.run(decay_lr)
                acc0,acc1,acc2,acc3,loss_ = sess.run([accuracy0,accuracy1,accuracy2,accuracy3,total_loss],
                                                    feed_dict=feed)
                learning_rate = sess.run(lr)
                print("Iter:%d  Loss:%.3f  Accuracy:%.2f,%.2f,%.2f,%.2f  Learning_rate:%.4f" % (i,loss_,acc0,acc1,acc2,acc3,learning_rate))

                if i == 6000:
                    saver.save(sess,"/root/tfrecord/model_captcha/crack_captcha.model",global_step=1)
                    break
    finally:
        # Always stop and join the queue-runner threads, even if the loop raised.
        # If a queue-runner thread failed (e.g. on a record that cannot be decoded),
        # it reported its exception to the coordinator and closed the input queue.
        # The main thread sees that closure only as OutOfRangeError.
        # coord.join() re-raises the exception reported to the coordinator, so the
        # underlying cause is shown as well.
        coord.request_stop()
        coord.join(threads)
    

Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(string_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensor_slices(input_tensor).shuffle(tf.shape(input_tensor, out_type=tf.int64)[0]).repeat(num_epochs)`. If `shuffle=False`, omit the `.shuffle(...)`.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.Dataset.from_tensors(tensor).repeat(num_epochs)`.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
To construct input pipelines, use the `tf.data` module.
Instructions for updating:
Queue-based input pipelines have been replaced by `tf.data`. Use `tf.data.TFRecordDataset`.
Instructions for updating:
Queue-base

OutOfRangeError: RandomShuffleQueue '_1_shuffle_batch/random_shuffle_queue' is closed and has insufficient elements (requested 25, current size 4)
	 [[node shuffle_batch (defined at /usr/local/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py:1748) ]]

Original stack trace for 'shuffle_batch':
  File "/usr/local/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.6/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.6/site-packages/traitlets/config/application.py", line 664, in launch_instance
    app.start()
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelapp.py", line 563, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.6/site-packages/tornado/platform/asyncio.py", line 148, in start
    self.asyncio_loop.run_forever()
  File "/usr/local/lib/python3.6/asyncio/base_events.py", line 421, in run_forever
    self._run_once()
  File "/usr/local/lib/python3.6/asyncio/base_events.py", line 1425, in _run_once
    handle._run()
  File "/usr/local/lib/python3.6/asyncio/events.py", line 127, in _run
    self._callback(*self._args)
  File "/usr/local/lib/python3.6/site-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: self._run_callback(functools.partial(callback, future))
  File "/usr/local/lib/python3.6/site-packages/tornado/ioloop.py", line 743, in _run_callback
    ret = callback()
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 787, in inner
    self.run()
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 377, in dispatch_queue
    yield self.process_one()
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 225, in wrapper
    runner = Runner(result, future, yielded)
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 714, in __init__
    self.run()
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 748, in run
    yielded = self.gen.send(value)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 361, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 268, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/kernelbase.py", line 541, in execute_request
    user_expressions, allow_stdin,
  File "/usr/local/lib/python3.6/site-packages/tornado/gen.py", line 209, in wrapper
    yielded = next(result)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/ipkernel.py", line 300, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.6/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2848, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 2874, in _run_cell
    return runner(coro)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3051, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3242, in run_ast_nodes
    if (await self.run_code(code, result,  async_=asy)):
  File "/usr/local/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3319, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-2-e81e2f287685>", line 5, in <module>
    batch_size=BATCH_SIZE,capacity=50000,min_after_dequeue=10000,num_threads=1)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 324, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/input.py", line 1347, in shuffle_batch
    name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/training/input.py", line 874, in _shuffle_batch
    dequeued = queue.dequeue_many(batch_size, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/ops/data_flow_ops.py", line 489, in dequeue_many
    self._queue_ref, n=n, component_types=self._dtypes, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_data_flow_ops.py", line 3862, in queue_dequeue_many_v2
    timeout_ms=timeout_ms, name=name)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/framework/op_def_library.py", line 794, in _apply_op_helper
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3357, in create_op
    attrs, op_def, compute_device)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 3426, in _create_op_internal
    op_def=op_def)
  File "/usr/local/lib/python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 1748, in __init__
    self._traceback = tf_stack.extract_stack()


本例主要是运用 slim 下的 AlexNet 网络进行多任务训练，得到识别验证码的模型。