In [55]:
import os
# These environment variables are set before TensorFlow is imported below.
# NOTE(review): presumably this pins the process to the GPU with PCI bus index 1;
# confirm the index against the target machine.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"
from tensorflow.contrib.slim.nets import resnet_v2, resnet_utils
import tensorflow as tf
from tensorflow.contrib import layers as layers_lib
from tensorflow.python.ops import variable_scope
from tensorflow.contrib.layers.python.layers import utils
from tensorflow.contrib import slim
from tensorflow.nn import ctc_loss, conv2d
import numpy as np
# Short aliases for the slim ResNet v2 builders. Order matters: the second line
# rebinds the name `resnet_v2` from the imported module to the function
# `resnet_v2.resnet_v2`, so `resnet_v2_block` has to be taken from the module first.
resnet_v2_block = resnet_v2.resnet_v2_block
resnet_v2 = resnet_v2.resnet_v2

import matplotlib.pyplot as plt

In [2]:
def resnet_v2_26_base(inputs,
                      num_classes=None,
                      is_training=False,
                      global_pool=False,
                      output_stride=1,
                      reuse=None,
                      include_root_block=False,
                      scope='resnet_v2_26'):
    """
    Build the ResNet v2 body used by the OCR network.

    TensorFlow's resnet_v2 uses only bottleneck units (three layers each); the
    original author counts 26 layers for this configuration of four blocks with
    two units per block. Every block is declared with stride=2 in order to
    enlarge the receptive field.

    Args:
        inputs: input tensor, passed straight to resnet_v2.
        num_classes: forwarded to resnet_v2 (None by default).
        is_training: forwarded to resnet_v2; set it to True when the batch-norm
            layers should be updated.
        global_pool: forwarded to resnet_v2.
        output_stride: requested effective stride of the network.
        reuse: forwarded to resnet_v2.
        include_root_block: whether resnet_v2 adds its own first conv layer.
            Disabled by default because the max pool that comes with it is
            suppressed; a large receptive field is wanted instead.
        scope: variable scope name for the network.

    Returns:
        Whatever resnet_v2 returns for these arguments.
    """
    block_depths = (64, 128, 256, 512)
    blocks = [
        resnet_v2_block('block%d' % position, base_depth=depth, num_units=2, stride=2)
        for position, depth in enumerate(block_depths, start=1)
    ]
    return resnet_v2(
        inputs,
        blocks,
        num_classes=num_classes,
        is_training=is_training,
        global_pool=global_pool,
        output_stride=output_stride,
        include_root_block=include_root_block,
        reuse=reuse,
        scope=scope)

In [3]:
def make_ocr_net(inputs, num_classes=200, is_training=False):
    '''
    Create the OCR network graph: a ResNet feature extractor followed by a
    per-time-step linear classifier.

    The root convolution has stride 2, so the image width is halved; the
    resulting width defines the number of time steps.
    After the ResNet head the features are rearranged to
    [batch, 1, width, channels * height] and passed through a 1x1 convolution,
    which acts as a dense connection applied to every time step.
    No activation (no softmax) is applied to the output, because it is already
    present in ctc_loss() and in the beam search decoder.

    Args:
        inputs: batch of images. The feature height must be statically known,
            because it is folded into the channel axis.
        num_classes: size of the last dimension of the returned logits.
        is_training: forwarded to resnet_v2_26_base.

    Returns:
        Tensor of size [batch, time_stamps_width, num_classes].

    Raises:
        ValueError: if the height or depth of the ResNet features is not
            statically known.
    '''
    with tf.variable_scope('resnet_base', values=[inputs]):
        # NOTE(review): this arg_scope also wraps the resnet_v2_26_base call, so
        # activation_fn=None / normalizer_fn=None become the defaults for any
        # slim.conv2d created inside the ResNet body as well, not only for the
        # root conv. Confirm that this is intended.
        with slim.arg_scope([slim.conv2d],
                            activation_fn=None, normalizer_fn=None):
            net = resnet_utils.conv2d_same(inputs, 64, 7, stride=2, scope='conv1')  # root conv for resnet
            # The max_pool2d([3, 3], stride=2) that used to follow the root conv
            # is left out on purpose (author's note: to enlarge the receptive field).
            # The call returns a tuple of the last tensor and all tensors; keep the first.
            net = resnet_v2_26_base(net, output_stride=1, is_training=is_training)[0]
    with tf.variable_scope('class_head', values=[net]):
        # The depth fed to the conv op must be static, so height and channels are
        # read as plain Python ints; batch size and width may stay dynamic.
        _, height, _, channels = net.get_shape().as_list()
        if height is None or channels is None:
            raise ValueError(
                'make_ocr_net needs a statically known feature height and depth, '
                'got shape %s' % net.get_shape())
        # Fold every column into the channel axis.
        net = tf.transpose(net, [0, 3, 1, 2])  # [batch, c, h, w]
        dynamic_shape = tf.shape(net)
        net = tf.reshape(net, [dynamic_shape[0], channels * height, 1, dynamic_shape[3]])
        net = tf.transpose(net, [0, 2, 3, 1])  # back to [batch, h, w, c] = [batch, 1, w, features*h]
        net = layers_lib.conv2d(net, num_classes, [1, 1], activation_fn=None)  # [batch, 1, w, num_classes]
        net = tf.squeeze(net, 1)  # [batch, w, num_classes]
        return net

In [4]:
def ctc_loss_layer(sequence_labels, logits, sequence_length):
    """
    Mean CTC loss over the batch, used as the training objective.

    sequence_labels: labels handed to tf.nn.ctc_loss unchanged.
    logits: batch-major network output, [batch_size, max_time, num_classes].
    sequence_length: list of sequence lengths, len(sequence_length) = batch_size.
        Here the sequences cannot differ in size, because they originate from
        a batch of images, which must be of equal size (e.g. padded).
    """
    per_sample_loss = tf.nn.ctc_loss(
        labels=sequence_labels,
        inputs=logits,
        sequence_length=sequence_length,
        ignore_longer_outputs_than_inputs=True,
        time_major=False)  # logits are [batch_size, max_time, num_classes]
    return tf.reduce_mean(per_sample_loss)

In [10]:
def get_training(sequence_labels, net_logits, sequence_length,
                 learning_rate=1e-4, decay_steps=2**16, decay_rate=0.9, decay_staircase=False,
                 momentum=0.9):
    """
    Set up the training ops.

    Builds the CTC loss for `net_logits`, an exponentially decayed learning
    rate driven by the graph's global step, and an Adam optimizer (with
    `momentum` used as beta1) over all trainable variables.

    Returns:
        (train_op, loss)
    """
    with tf.name_scope("train"):
        trainable_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        loss = ctc_loss_layer(sequence_labels, net_logits, sequence_length)
        # Update batch norm stats [http://stackoverflow.com/questions/43234667]
        batchnorm_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(batchnorm_updates):
            step = tf.train.get_global_step()
            # Learning-rate schedule derived from the function parameters.
            decayed_lr = tf.train.exponential_decay(
                learning_rate=learning_rate,
                global_step=step,
                decay_steps=decay_steps,
                decay_rate=decay_rate,
                staircase=decay_staircase,
                name='learning_rate')
            adam = tf.train.AdamOptimizer(learning_rate=decayed_lr, beta1=momentum)
            train_op = tf.contrib.layers.optimize_loss(
                loss=loss,
                global_step=step,
                learning_rate=decayed_lr,
                optimizer=adam,
                variables=trainable_vars)
            tf.summary.scalar('learning_rate', decayed_lr)
    return train_op, loss

def get_prediction(output_net, seq_len, merge_repeated=False):
    '''
    Decode the network output with CTC beam search.

    output_net: batch-major logits; they are transposed to
        [time, batch, logits] before decoding.
    seq_len: sequence lengths handed to the decoder.
    merge_repeated: forwarded to tf.nn.ctc_beam_search_decoder.

    Returns the two values produced by the decoder, as (decoded, prob).
    '''
    time_major_logits = tf.transpose(output_net, perm=[1, 0, 2])
    paths, scores = tf.nn.ctc_beam_search_decoder(
        inputs=time_major_logits,
        sequence_length=seq_len,
        merge_repeated=merge_repeated)
    return paths, scores

In [11]:
# Alphabet recognised by the model; the position of a character in this string
# is its integer label.
all_chars = ' !"%&()*+,-./0123456789:=ABCDEFGHIJKLMNOPQRSTUVWXYZ\\_abcdefghijklmnopqrstuvwxyz|~ЁАБВГДЕЖЗИЙКЛМНОПРСТУФХЦЧШЩЪЫЬЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№'
char_to_indx = {char: indx for indx, char in enumerate(all_chars)}
# tf.nn.ctc_loss reserves the largest class id (num_classes - 1) for the blank
# label, so the network needs one output more than there are characters.
# Without the +1 the last character of the alphabet would share its id with blank.
num_classes = len(all_chars) + 1
def string_to_label(string):
    """Encode `string` as a list of integer labels, one per character.

    Each character is looked up in the module-level `char_to_indx` table;
    a character that is missing from the table raises KeyError.
    """
    encoded = []
    for symbol in string:
        encoded.append(char_to_indx[symbol])
    return encoded

def batch_to_sparse(batch, dtype=np.int32):  # batch of words
    '''
    Return the sparse representation of a batch of label strings.

    input:  batch - batch of words (list or numpy array of str)
            dtype - numpy dtype of all three returned arrays
    output: indices - array of shape [N, 2] with one (batch_num, char_position)
                      row per character, N being the total number of characters
            values - array of shape [N] with the label of every character,
                     as produced by string_to_label
            shape - dense shape of the batch: [len(batch), longest word length]

    An empty batch, or a batch that contains only empty strings, yields
    indices of shape [0, 2], empty values and a dense shape of [len(batch), 0].
    '''
    assert isinstance(batch, (list, np.ndarray)), 'batch should be a list or numpy array of strings'
    indices = []  # (batch_num, char_position) pairs
    values = []   # char labels
    longest_word = 0
    for batch_num, word in enumerate(batch):
        assert isinstance(word, str), 'batch element should be a string'
        word_as_indx = string_to_label(word)
        indices.extend((batch_num, char_num) for char_num in range(len(word_as_indx)))
        values.extend(word_as_indx)
        longest_word = max(longest_word, len(word_as_indx))
    # reshape keeps the [N, 2] layout even when there are no characters at all,
    # where np.asarray alone would give a 1-D array of length 0.
    indices = np.asarray(indices, dtype=dtype).reshape(-1, 2)
    values = np.asarray(values, dtype=dtype)
    shape = np.array([len(batch), longest_word], dtype=dtype)
    return indices, values, shape
def from_sparse_to_batch():
    """Not implemented yet: takes no arguments and returns None.

    NOTE(review): judging by the name this is presumably meant to become the
    inverse of batch_to_sparse - confirm before relying on it.
    """
    return None

In [27]:
graph = tf.Graph()
with graph.as_default():
    # Image height must be 32; batch size and image width are left dynamic.
    input_image = tf.placeholder(shape=(None, 32, None, 3), dtype=tf.float32)
    input_label = tf.sparse_placeholder(tf.int32, name='label')

    # NOTE(review): the network is built with is_training=False although the
    # same graph is trained below - confirm that this is intended.
    output_net = make_ocr_net(input_image, num_classes=num_classes, is_training=False)
    tf.train.create_global_step()

    # Every sample spans the full width of the logits, so the per-sample
    # sequence length is read from the logits shape [batch, time, classes]
    # instead of being hard-coded for one batch size and one image width.
    logits_shape = tf.shape(output_net)
    logits_seq_len = tf.fill([logits_shape[0]], logits_shape[1])

    train_op, loss = get_training(input_label, output_net, logits_seq_len)
    prediction = get_prediction(output_net, seq_len=logits_seq_len, merge_repeated=False)
    init = tf.global_variables_initializer()

In [47]:
with tf.Session(graph=graph) as sess:
    # Only the graph definition is written to the 'log' directory in this cell,
    # so the writer is closed right away to flush the event file and release
    # its file handle.
    graph_writer = tf.summary.FileWriter('log', sess.graph)
    graph_writer.close()
    sess.run(init)

    # Synthetic sample: an all-zero 32x180 image with a single vertical line
    # (column 60, channel 0), labelled 'hello'.
    image_vert_line = np.zeros((1,32,180,3))
    image_vert_line[0,:,60,0] = 1
    batch = batch_to_sparse(['hello'])

    # Fit the network to this single (image, label) pair.
    for _ in range(1000):
        _, loss_value = sess.run([train_op, loss], feed_dict={input_image: image_vert_line, input_label: batch})
    print(loss_value)
    decoded, prob = sess.run(prediction, feed_dict={input_image: image_vert_line})
    # First global variable created under the 'resnet_base' scope.
    first_filters = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='resnet_base')[0]


9.8028095e-06
