In [None]:
%tensorflow_version 1.x

TensorFlow 1.x selected.


In [None]:
import tensorflow as tf
from tensorflow.contrib.layers import batch_norm
from tensorflow.contrib.framework import arg_scope
from matplotlib import pyplot as plt
from tensorflow.keras import regularizers
import numpy as np
import random
import sys
import copy
import re
import os
import time
import math
import pandas as pd
import pickle as pkl
import gzip
import cv2
import matplotlib.cm as cm
from skimage.transform import rescale, resize
# Module-wide NumPy random generator used by norm_weight() and
# conv_norm_weight(). It is seeded from the wall clock, so the initial
# weights differ from one run to the next.
rng = np.random.RandomState(int(time.time()))


# Expected width of the decoder's output distribution; WAP.get_sample()
# asserts that the vocabulary size it sees equals this value.
NUM_CLASSES=130
# Batch size handed to dataIterator() for the train and validation sets.
BATCHSIZE=4
# Run-mode switch. NOTE(review): its consumer is not visible in this part
# of the file -- confirm which values are accepted.
MODE='test'

##Only include while running on colab
#from google.colab import drive
#drive.mount('/gdrive', force_remount=True)

'''
Following three functions:
norm_weight(),
conv_norm_weight(),
ortho_weight()
are initialization methods for weights.
'''
def norm_weight(fan_in, fan_out):
    """Return a (fan_in, fan_out) float32 matrix of uniform random weights.

    Entries are drawn from U(-b, b) with b = sqrt(6 / (fan_in + fan_out)),
    using the module-level ``rng`` generator.
    """
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    samples = rng.uniform(low=-limit, high=limit, size=(fan_in, fan_out))
    return np.asarray(samples, dtype=np.float32)
 
def conv_norm_weight(nin, nout, kernel_size):
    """Return a float32 convolution filter bank of uniform random weights.

    The result has shape (kernel_size[0], kernel_size[1], nin, nout) and is
    drawn from U(-b, b) with b = sqrt(6 / (fan_in + fan_out)), where fan_in
    and fan_out are the kernel area times ``nin`` and ``nout`` respectively.
    Sampling uses the module-level ``rng`` generator.
    """
    k_rows, k_cols = kernel_size[0], kernel_size[1]
    fan_in = k_rows * k_cols * nin
    fan_out = k_rows * k_cols * nout
    limit = np.sqrt(6. / (fan_in + fan_out))
    samples = rng.uniform(low=-limit, high=limit,
                          size=(k_rows, k_cols, nin, nout))
    return np.asarray(samples, dtype=np.float32)

def ortho_weight(ndim):
    """Return a random (ndim, ndim) orthogonal matrix as float32.

    A standard-normal matrix is drawn from NumPy's global random state and
    the left singular vectors of its SVD are returned.
    """
    gaussian = np.random.randn(ndim, ndim)
    left_vectors = np.linalg.svd(gaussian)[0]
    return left_vectors.astype('float32')

'''
The Watcher_train class implements the DenseNet encoder
using bottleneck layers. The encoder contains three dense blocks.
'''
class Watcher_train():
    """DenseNet encoder built from bottleneck layers.

    The graph is a strided input convolution followed by max-pooling, then
    ``blocks`` dense blocks of ``level`` bottleneck layers each, with a
    compressing transition layer (1x1 conv + average pooling) between
    consecutive blocks.
    """

    def __init__(self, blocks,       # number of dense blocks
                level,                     # number of bottleneck layers in each block
                growth_rate,               # growth rate in DenseNet paper: k
                training,                  # flag forwarded to batch-norm / dropout
                dropout_rate=0.2,          # fraction of units DROPPED (tf.layers.dropout `rate`)
                dense_channels=0,          # running channel count before the input conv
                transition=0.5,            # compression ratio of the transition layers
                input_conv_filters=48,     # filter numbers of conv2d before dense blocks
                input_conv_stride=2,       # stride of conv2d before dense blocks
                input_conv_kernel=[7,7]):  # kernel size of conv2d before dense blocks
        # NOTE: the list default for input_conv_kernel is shared between
        # instances; it is only ever read here, never mutated.
        self.blocks = blocks
        self.growth_rate = growth_rate
        self.training = training
        self.dense_channels = dense_channels
        self.level = level
        self.dropout_rate = dropout_rate
        self.transition = transition
        self.input_conv_kernel = input_conv_kernel
        self.input_conv_stride = input_conv_stride
        self.input_conv_filters = input_conv_filters

    def bound(self, nin, nout, kernel):
        """Return the uniform-init limit sqrt(6 / (fan_in + fan_out)).

        fan_in / fan_out are ``nin`` / ``nout`` times the kernel area.
        """
        fin = nin * kernel[0] * kernel[1]
        fout = nout * kernel[0] * kernel[1]
        return np.sqrt(6. / (fin + fout))

    def _conv(self, x, nin, nout, kernel, stride, padding):
        """Bias-free conv2d whose kernel is drawn from U(-bound, bound)."""
        limit = self.bound(nin, nout, kernel)
        return tf.layers.conv2d(
            x, filters=nout, kernel_size=kernel, strides=stride,
            padding=padding, data_format='channels_last', use_bias=False,
            kernel_initializer=tf.random_uniform_initializer(-limit, limit, dtype=tf.float32))

    def _batch_norm(self, x, channels):
        """Batch normalisation with gamma drawn from U(-1/sqrt(c), 1/sqrt(c))."""
        gamma_limit = 1.0 / math.sqrt(channels)
        return tf.layers.batch_normalization(
            x, training=self.training, momentum=0.9, scale=True,
            gamma_initializer=tf.random_uniform_initializer(-gamma_limit, gamma_limit, dtype=tf.float32),
            epsilon=0.0001)

    def _dropout(self, x):
        """Dropout that is only active when the training flag is set."""
        return tf.layers.dropout(inputs=x, rate=self.dropout_rate, training=self.training)

    def dense_net(self, input_x, mask_x):
        """Build the encoder graph and return ``(features, mask)``.

        Args:
            input_x: image batch without a channel axis; one is appended at
                axis 3 before the first convolution.
            mask_x: mask indexed as ``[:, rows, cols]``. It is subsampled by
                2 along rows and columns at every point where the feature
                map is downsampled.

        Returns:
            The output of the last dense block and the correspondingly
            subsampled mask.

        Note:
            ``self.dense_channels`` is updated in place while the graph is
            built, so this method is meant to be called once per instance.
        """
        #### before flowing into dense blocks ####
        x = tf.expand_dims(input=input_x, axis=3)
        x = self._conv(x, 1, self.input_conv_filters, self.input_conv_kernel,
                       self.input_conv_stride, 'SAME')
        mask_x = mask_x[:, 0::2, 0::2]
        x = self._batch_norm(x, self.input_conv_filters)
        x = tf.nn.relu(x)
        x = tf.layers.max_pooling2d(inputs=x, pool_size=[2,2], strides=2, padding='SAME')
        mask_x = mask_x[:, 0::2, 0::2]

        self.dense_channels += self.input_conv_filters
        dense_out = x

        #### flowing into dense blocks and transition_layer ####
        for i in range(self.blocks):
            for _ in range(self.level):

                #### [1, 1] convolution part for bottleneck ####
                x = self._conv(x, self.dense_channels, 4 * self.growth_rate, [1,1], 1, 'VALID')
                x = self._batch_norm(x, 4 * self.growth_rate)
                x = tf.nn.relu(x)
                x = self._dropout(x)

                #### [3, 3] convolution part for regular convolve operation
                x = self._conv(x, 4 * self.growth_rate, self.growth_rate, [3,3], 1, 'SAME')
                x = self._batch_norm(x, self.growth_rate)
                x = tf.nn.relu(x)
                x = self._dropout(x)

                # Dense connectivity: every layer sees all earlier outputs.
                dense_out = tf.concat([dense_out, x], axis=3)
                x = dense_out
                #### calculate the filter number of dense block's output ####
                self.dense_channels += self.growth_rate

            if i < self.blocks - 1:
                compressed_channels = int(self.dense_channels * self.transition)

                # The init bound must use the channel count ENTERING the
                # transition conv, so build the conv before the running
                # counter is replaced by the compressed value.
                x = self._conv(x, self.dense_channels, compressed_channels, [1,1], 1, 'VALID')

                #### new dense channels for new dense block ####
                self.dense_channels = compressed_channels
                x = self._batch_norm(x, compressed_channels)
                x = tf.nn.relu(x)
                x = self._dropout(x)

                x = tf.layers.average_pooling2d(inputs=x, pool_size=[2,2], strides=2, padding='SAME')
                dense_out = x
                mask_x = mask_x[:, 0::2, 0::2]

        return dense_out, mask_x

'''
Attender class implements contextual attention mechanism. 
'''
class Attender():
    """Coverage-aware attention over the encoder feature map.

    Holds the projection weights for the three terms that are summed inside
    the tanh (coverage, annotation and decoder state), the scoring vector
    ``V_a`` and the convolution filter applied to the accumulated attention.
    """

    def __init__(self, channels,                          # output of Watcher | [batch, h, w, channels]
                dim_decoder, dim_attend):                 # decoder hidden state:$h_{t-1}$ | [batch, dec_dim]

        def zero_bias(size):
            # float32 zero vector used to initialise every bias below
            return np.zeros((size,)).astype('float32')

        self.channels = channels

        self.coverage_kernel = [11,11]          # kernel size of $Q$
        self.coverage_filters = dim_attend      # filter numbers of $Q$

        self.dim_decoder = dim_decoder
        self.dim_attend = dim_attend            # shared width of the three terms of $e_ti$

        # $U_f x f_i + U_f_b$ | [cov_filters, dim_attend], [dim_attend, ]
        self.U_f = tf.Variable(norm_weight(self.coverage_filters, self.dim_attend), name='U_f')
        self.U_f_b = tf.Variable(zero_bias(self.dim_attend), name='U_f_b')

        # $U_a x a_i + U_a_b$ | [annotation_channels, dim_attend], [dim_attend, ]
        self.U_a = tf.Variable(norm_weight(self.channels,self.dim_attend), name='U_a')
        self.U_a_b = tf.Variable(zero_bias(self.dim_attend), name='U_a_b')

        # $W_a x h_{t-1} + W_a_b$ | [dec_dim, dim_attend], [dim_attend, ]
        self.W_a = tf.Variable(norm_weight(self.dim_decoder,self.dim_attend), name='W_a')
        self.W_a_b = tf.Variable(zero_bias(self.dim_attend), name='W_a_b')

        # $V_a x tanh(A + B + C) + V_a_b$ | [dim_attend, 1], [1, ]
        self.V_a = tf.Variable(norm_weight(self.dim_attend, 1), name='V_a')
        self.V_a_b = tf.Variable(zero_bias(1), name='V_a_b')

        self.alpha_past_filter = tf.Variable(
            conv_norm_weight(1, self.dim_attend, self.coverage_kernel),
            name='alpha_past_filter')

    def get_context(self, annotation4ctx, h_t_1, alpha_past4ctx, a_mask):
        """Compute one attention step.

        Args:
            annotation4ctx: encoder features, [batch, h, w, channels].
            h_t_1: decoder state used as the query, [batch, dim_decoder].
            alpha_past4ctx: sum of earlier attention maps, [batch, h, w].
            a_mask: optional [batch, h, w] mask multiplied into the
                un-normalised weights; ``None`` disables masking.

        Returns:
            ``(context, alpha, alpha_past)``: the attention-weighted sum of
            the annotations, the normalised weights of this step, and the
            accumulated weights including this step.
        """
        #### coverage term: conv over past attention, then $U_f x f_i$ ####
        past_with_channel = alpha_past4ctx[..., tf.newaxis]
        coverage_features = tf.nn.conv2d(
            past_with_channel, filter=self.alpha_past_filter,
            strides=[1, 1, 1, 1], padding='SAME')
        coverage_term = tf.tensordot(coverage_features, self.U_f, axes=1) + self.U_f_b  # [batch, h, w, dim_attend]

        #### annotation term: $U_a x a_i$ ####
        annotation_term = tf.tensordot(annotation4ctx, self.U_a, axes=1) + self.U_a_b   # [batch, h, w, dim_attend]

        #### query term: $W_a x h_{t - 1}$, broadcast over h and w ####
        query_term = tf.tensordot(h_t_1, self.W_a, axes=1) + self.W_a_b                 # [batch, dim_attend]
        query_term = query_term[:, None, None, :]

        activated = tf.tanh(coverage_term + annotation_term + query_term)               # [batch, h, w, dim_attend]
        energies = tf.tensordot(activated, self.V_a, axes=1) + self.V_a_b               # [batch, h, w, 1]
        alpha = tf.squeeze(tf.exp(energies), axis=3)

        if a_mask is not None:
            alpha = alpha * a_mask

        # normalise over the two spatial axes | [batch, h, w]
        alpha = alpha / tf.reduce_sum(alpha, axis=[1, 2], keepdims=True)
        alpha_past4ctx += alpha    # accumulated weights | [batch, h, w]
        weighted = annotation4ctx * alpha[:, :, :, None]
        context = tf.reduce_sum(weighted, axis=[1, 2])    # [batch, feature_channels]
        return context, alpha, alpha_past4ctx

'''
The Parser class implements a two-layer GRU decoder which decodes an input image
and outputs a sequence of characters using the attention mechanism.
'''
class Parser():
    """Two stacked GRU cells with an attention step between them.

    The first cell consumes the previous symbol's embedding, its output
    queries the attender for a context vector, and the second cell combines
    that context with the first cell's output.
    """

    def __init__(self, hidden_dim, word_dim, attender, context_dim):

        def zero_bias(size):
            # float32 zero vector used to initialise every bias below
            return np.zeros((size, )).astype('float32')

        def side_by_side(make):
            # two independent draws concatenated along the column axis
            return np.concatenate([make(), make()], axis=1)

        self.attender = attender            # Attender instance that supplies the context
        self.context_dim = context_dim      # width of the context vector
        self.hidden_dim = hidden_dim        # width of the hidden state
        self.word_dim = word_dim            # width of the symbol embedding

        ## GRU 1 weights
        self.W_yz_yr = tf.Variable(
            side_by_side(lambda: norm_weight(self.word_dim, self.hidden_dim)),
            name='W_yz_yr')                                         # [dim_word, 2 * dim_hidden]
        self.b_yz_yr = tf.Variable(zero_bias(2 * self.hidden_dim), name='b_yz_yr')

        self.U_hz_hr = tf.Variable(
            side_by_side(lambda: ortho_weight(self.hidden_dim)),
            name='U_hz_hr')                                         # [dim_hidden, 2 * dim_hidden]

        self.W_yh = tf.Variable(norm_weight(self.word_dim, self.hidden_dim), name='W_yh')
        self.b_yh = tf.Variable(zero_bias(self.hidden_dim), name='b_yh')      # [dim_hidden, ]

        self.U_rh = tf.Variable(ortho_weight(self.hidden_dim), name='U_rh')   # [dim_hidden, dim_hidden]

        ## GRU 2 weights
        self.U_hz_hr_nl = tf.Variable(
            side_by_side(lambda: ortho_weight(self.hidden_dim)),
            name='U_hz_hr_nl')                                      # [dim_hidden, 2 * dim_hidden]
        self.b_hz_hr_nl = tf.Variable(zero_bias(2 * self.hidden_dim), name='b_hz_hr_nl')

        self.W_c_z_r = tf.Variable(
            norm_weight(self.context_dim, 2 * self.hidden_dim), name='W_c_z_r')

        self.U_rh_nl = tf.Variable(ortho_weight(self.hidden_dim), name='U_rh_nl')
        self.b_rh_nl = tf.Variable(zero_bias(self.hidden_dim), name='b_rh_nl')

        self.W_c_h_nl = tf.Variable(
            norm_weight(self.context_dim, self.hidden_dim), name='W_c_h_nl')

    def get_ht_ctx(self, emb_y, target_hidden_state_0, annotations, a_m, y_m):
        """Run ``one_time_step`` over all time steps with ``tf.scan``.

        The scan carries (hidden state, context, alpha, accumulated alpha,
        annotations, annotation mask); context and both alpha slots start as
        zeros sized from the annotation tensor. Returns the scan result,
        i.e. one stacked tensor per carried element.
        """
        anno_shape = tf.shape(annotations)
        batch, height, width = anno_shape[0], anno_shape[1], anno_shape[2]
        initial_carry = (
            target_hidden_state_0,
            tf.zeros([batch, self.context_dim]),
            tf.zeros([batch, height, width]),
            tf.zeros([batch, height, width]),
            annotations,
            a_m,
        )
        return tf.scan(self.one_time_step, elems=(emb_y, y_m), initializer=initial_carry)

    def one_time_step(self, tuple_h0_ctx_alpha_alpha_past_annotation, tuple_emb_mask):
        """Advance the decoder by one symbol.

        Args:
            tuple_h0_ctx_alpha_alpha_past_annotation: carried state; slots
                0, 3, 4 and 5 (previous hidden state, accumulated alpha,
                annotations, annotation mask) are read, slots 1 and 2 are
                ignored.
            tuple_emb_mask: ``(embedding, mask)`` for this step; where the
                mask is 0 the previous state is kept. ``None`` skips masking.

        Returns:
            ``(h, context, alpha, alpha_past, annotations, a_mask)``.
        """
        carry = tuple_h0_ctx_alpha_alpha_past_annotation
        prev_h = carry[0]
        coverage = carry[3]
        annotations = carry[4]
        a_mask = carry[5]

        emb_y, y_mask = tuple_emb_mask
        width = self.hidden_dim

        # ---- GRU 1: driven by the previous symbol's embedding ----
        gate_from_emb = tf.tensordot(emb_y, self.W_yz_yr, axes=1) + self.b_yz_yr   # [batch, 2 * dim_hidden]
        gate_from_h = tf.tensordot(prev_h, self.U_hz_hr, axes=1)                   # [batch, 2 * dim_hidden]
        gates_1 = tf.sigmoid(gate_from_emb + gate_from_h)
        r1 = gates_1[:, :width]                                                    # [batch, dim_hidden]
        z1 = gates_1[:, width:]                                                    # [batch, dim_hidden]

        cand_from_emb = tf.tensordot(emb_y, self.W_yh, axes=1) + self.b_yh         # [batch, dim_hidden]
        cand_from_h = tf.tensordot(prev_h, self.U_rh, axes=1)                      # [batch, dim_hidden]
        cand_from_h *= r1
        candidate_1 = tf.tanh(cand_from_h + cand_from_emb)

        pre_h = z1 * prev_h + (1. - z1) * candidate_1

        if y_mask is not None:
            pre_h = y_mask[:, None] * pre_h + (1. - y_mask)[:, None] * prev_h

        # ---- attention queried with the intermediate state ----
        context, alpha, coverage = self.attender.get_context(
            annotations, pre_h, coverage, a_mask)                                  # context: [batch, dim_ctx]

        # ---- GRU 2: driven by the context vector ----
        gate_from_pre = tf.tensordot(pre_h, self.U_hz_hr_nl, axes=1) + self.b_hz_hr_nl
        gate_from_ctx = tf.tensordot(context, self.W_c_z_r, axes=1)
        gates_2 = tf.sigmoid(gate_from_pre + gate_from_ctx)
        r2 = gates_2[:, :width]
        z2 = gates_2[:, width:]

        cand_from_pre = tf.tensordot(pre_h, self.U_rh_nl, axes=1)
        cand_from_pre *= r2
        cand_from_pre = cand_from_pre + self.b_rh_nl    # bias is added after the gating by r2
        cand_from_ctx = tf.tensordot(context, self.W_c_h_nl, axes=1)
        candidate_2 = tf.tanh(cand_from_pre + cand_from_ctx)
        h = z2 * pre_h + (1. - z2) * candidate_2

        if y_mask is not None:
            h = y_mask[:, None] * h + (1. - y_mask)[:, None] * pre_h

        return h, context, alpha, coverage, annotations, a_mask


'''
WAP is the main class. It uses the three classes below:
1) Watcher_train (Encoder)
2) Attender (Contextual attention mechanism)
3) Parser (2-layer GRU Decoder)
The WAP class implements two functions, get_cost and get_sample, which are used for cost calculation and decoding.
'''
class WAP():
    """Watch-Attend-Parse model: training cost, one-step decoding and search.

    Owns the symbol embedding, the projection used to initialise the decoder
    state and the maxout output layer; the recurrent step itself lives in
    the supplied ``parser``.
    """

    def __init__(self, watcher, attender, parser, hidden_dim, word_dim, context_dim, target_dim, training):

        self.hidden_dim = hidden_dim
        self.word_dim = word_dim
        self.context_dim = context_dim
        self.target_dim = target_dim
        self.embed_matrix = tf.Variable(norm_weight(self.target_dim, self.word_dim), name='embed')

        self.watcher = watcher
        self.attender = attender
        self.parser = parser
        # projection of the mean annotation to the initial hidden state
        self.Wa2h = tf.Variable(norm_weight(self.context_dim, self.hidden_dim), name='Wa2h')
        self.ba2h = tf.Variable(np.zeros((self.hidden_dim,)).astype('float32'), name='ba2h')
        # output layer: context, hidden state and previous symbol -> word_dim
        self.Wc = tf.Variable(norm_weight(self.context_dim, self.word_dim), name='Wc')
        self.bc = tf.Variable(np.zeros((self.word_dim,)).astype('float32'), name='bc')
        self.Wh = tf.Variable(norm_weight(self.hidden_dim, self.word_dim), name='Wh')
        self.bh = tf.Variable(np.zeros((self.word_dim,)).astype('float32'), name='bh')
        self.Wy = tf.Variable(norm_weight(self.word_dim, self.word_dim), name='Wy')
        self.by = tf.Variable(np.zeros((self.word_dim,)).astype('float32'), name='by')
        # maxout halves word_dim before the projection to the vocabulary
        self.Wo = tf.Variable(norm_weight(self.word_dim//2, self.target_dim), name='Wo')
        self.bo = tf.Variable(np.zeros((self.target_dim,)).astype('float32'), name='bo')
        self.training = training

    def get_cost(self, cost_annotation, cost_y, a_m, y_m,alpha_reg):
        """Build the training loss.

        Args:
            cost_annotation: encoder features, [batch, h, w, context_dim].
            cost_y: target symbol ids, [timesteps, batch].
            a_m: annotation mask, [batch, h, w].
            y_m: target mask, [timesteps, batch].
            alpha_reg: scalar tensor; when > 0 an L1 penalty on the
                attention maps, scaled by this value, is added.

        Returns:
            Scalar tensor: masked cross-entropy summed over time and
            averaged over the batch, plus the optional attention penalty.
        """
        #### step 1: embed the labels and shift them right by one step ####
        timesteps = tf.shape(cost_y)[0]
        batch_size = tf.shape(cost_y)[1]
        emb_y = tf.nn.embedding_lookup(self.embed_matrix, tf.reshape(cost_y, [-1]))
        emb_y = tf.reshape(emb_y, [timesteps, batch_size, self.word_dim])
        emb_pad = tf.fill((1, batch_size, self.word_dim), 0.0)
        # zeros at t = 0, then the embeddings of steps 0 .. T-2
        emb_shift = tf.concat([emb_pad ,tf.strided_slice(emb_y, [0, 0, 0], [-1, batch_size, self.word_dim], [1, 1, 1])], axis=0)
        new_emb_y = emb_shift

        #### step 2: h_0 from the masked mean of the annotations ####
        anno_mean = tf.reduce_sum(cost_annotation * a_m[:, :, :, None], axis=[1, 2]) / tf.reduce_sum(a_m, axis=[1, 2])[:, None]
        h_0 = tf.tensordot(anno_mean, self.Wa2h, axes=1) + self.ba2h  # [batch, hidden_dim]
        h_0 = tf.tanh(h_0)

        #### step 3: h_t, c_t and alpha_t for all time steps ####
        ret = self.parser.get_ht_ctx(new_emb_y, h_0, cost_annotation, a_m, y_m)
        h_t = ret[0]                      # [timesteps, batch, hidden_dim]
        c_t = ret[1]                      # [timesteps, batch, context_dim]
        alpha = ret[2]                    # [timesteps, batch, h, w]

        #### step 4: output layer on h_t, c_t and y_{t-1} ####
        y_t_1 = new_emb_y                 # shifted y | [1:] = [:-1]
        logit_gru = tf.tensordot(h_t, self.Wh, axes=1) + self.bh
        logit_ctx = tf.tensordot(c_t, self.Wc, axes=1) + self.bc
        logit_pre = tf.tensordot(y_t_1, self.Wy, axes=1) + self.by
        logit = logit_pre + logit_ctx + logit_gru
        shape = tf.shape(logit)
        #### maxout: max over adjacent pairs halves word_dim ####
        logit = tf.reshape(logit, [shape[0], -1, shape[2]//2, 2])
        logit = tf.reduce_max(logit, axis=3)
        logit = tf.layers.dropout(inputs=logit, rate=0.2, training=self.training)
        logit = tf.tensordot(logit, self.Wo, axes=1) + self.bo
        logit_shape = tf.shape(logit)
        logit = tf.reshape(logit, [-1,logit_shape[2]])
        cost = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logit, labels=tf.one_hot(tf.reshape(cost_y, [-1]),depth=self.target_dim))

        #### mask out padded steps, sum over time, average over batch ####
        cost = tf.multiply(cost, tf.reshape(y_m, [-1]))
        cost = tf.reshape(cost, [shape[0], shape[1]])
        cost = tf.reduce_sum(cost, axis=0)
        cost = tf.reduce_mean(cost)

        #### alpha L1 regularization ####
        alpha_sum = tf.reduce_mean(tf.reduce_sum(tf.reduce_sum(tf.abs(alpha), axis=[2, 3]),axis=0))
        cost = tf.cond(alpha_reg > 0, lambda: cost + (alpha_reg * alpha_sum), lambda: cost)

        return cost

    def get_word(self, sample_y, sample_h_pre, alpha_past_pre, sample_annotation):
        """Build the one-step decoding graph used by ``get_sample``.

        ``get_sample`` first calls it with a batch of 1 and the start
        marker (-1), later with one row per live hypothesis.

        Args:
            sample_y: previous symbol ids, [batch]; a negative first entry
                selects an all-zero embedding of shape [1, word_dim].
            sample_h_pre: previous hidden state, [batch, hidden_dim].
            alpha_past_pre: accumulated attention, [batch, h, w].
            sample_annotation: encoder features, [batch, h, w, context_dim].

        Returns:
            ``(next_probs, next_word, h_t, alpha_past_t)``: the softmax over
            the vocabulary, one symbol sampled from it per row, the new
            hidden state and the updated accumulated attention.
        """
        emb = tf.cond(sample_y[0] < 0,
            lambda: tf.fill((1, self.word_dim), 0.0),
            lambda: tf.nn.embedding_lookup(self.embed_matrix, sample_y)
            )

        # Reuse the training-time recurrence so that decoding applies
        # exactly the same equations (in particular, b_rh_nl is added AFTER
        # the gating by r2). Passing None for both masks skips masking.
        h_t, c_t, _, alpha_past_t, _, _ = self.parser.one_time_step(
            (sample_h_pre, None, None, alpha_past_pre, sample_annotation, None),
            (emb, None))

        y_t_1 = emb
        logit_gru = tf.tensordot(h_t, self.Wh, axes=1) + self.bh
        logit_ctx = tf.tensordot(c_t, self.Wc, axes=1) + self.bc
        logit_pre = tf.tensordot(y_t_1, self.Wy, axes=1) + self.by
        logit = logit_pre + logit_ctx + logit_gru   # batch x word_dim

        #### maxout: max over adjacent pairs halves word_dim ####
        shape = tf.shape(logit)
        logit = tf.reshape(logit, [-1, shape[1]//2, 2])
        logit = tf.reduce_max(logit, axis=2)

        logit = tf.layers.dropout(inputs=logit, rate=0.2, training=self.training)

        logit = tf.tensordot(logit, self.Wo, axes=1) + self.bo

        next_probs = tf.nn.softmax(logits=logit)
        # tf.multinomial expects unnormalised log-probabilities, so it is
        # fed the logits rather than the softmax output.
        next_word = tf.reduce_max(tf.multinomial(logit, num_samples=1), axis=1)
        return next_probs, next_word, h_t, alpha_past_t

    def get_sample(self,p, w, h, alpha, ctx0, h_0, k , maxlen, stochastic, session, training):
        """Decode one image by beam search or by ancestral sampling.

        ``p, w, h, alpha`` are the tensors returned by ``get_word``; they are
        evaluated with ``session`` once per generated symbol. The feed uses
        the module-level placeholders ``anno``, ``infer_y``, ``alpha_past``,
        ``h_pre`` and ``if_trainning`` that ``main`` publishes via ``global``.

        Args:
            ctx0: encoder output for the image, [1, h, w, context_dim].
            h_0: initial decoder state, [1, hidden_dim].
            k: beam width (ignored when ``stochastic``).
            maxlen: maximum number of decoding steps.
            stochastic: True to follow the sampled symbol at each step,
                False for beam search.
            training: value fed to the training-flag placeholder.

        Returns:
            ``(sample, sample_score, next_alpha_past)``. For beam search,
            ``sample`` is a list of hypotheses and ``sample_score`` the
            matching list of accumulated negative log-probabilities. When
            sampling, ``sample`` is a single symbol list and ``sample_score``
            the running sum of the probabilities of the sampled symbols.
            ``next_alpha_past`` is the last accumulated attention array.
            Symbol 0 terminates a sequence.
        """
        sample = []
        # Sampling accumulates a scalar, beam search collects one score per
        # finished hypothesis.
        sample_score = 0.0 if stochastic else []

        live_k = 1
        dead_k = 0

        hyp_samples = [[]]
        hyp_scores = np.zeros(live_k).astype('float32')

        next_alpha_past = np.zeros((ctx0.shape[0], ctx0.shape[1], ctx0.shape[2])).astype('float32')
        next_w = -1 * np.ones((1,)).astype('int64')    # -1 marks "no previous symbol"
        next_state = h_0

        for ii in range(maxlen):

            # one copy of the annotations per live hypothesis
            ctx = np.tile(ctx0, [live_k, 1, 1, 1])

            input_dict = {
                anno: ctx,
                infer_y: next_w,
                alpha_past: next_alpha_past,
                h_pre: next_state,
                if_trainning: training,
            }

            next_p, next_w, next_state, next_alpha_past = session.run([p, w, h, alpha], feed_dict=input_dict)

            if stochastic:
                nw = next_w[0]
                sample.append(nw)
                sample_score += next_p[0, nw]
                if nw == 0:
                    break
            else:
                # cost of every (hypothesis, next symbol) extension
                cand_scores = hyp_scores[:, None] - np.log(next_p)
                cand_flat = cand_scores.flatten()
                ranks_flat = cand_flat.argsort()[:(k-dead_k)]
                voc_size = next_p.shape[1]

                assert voc_size==NUM_CLASSES

                trans_indices = ranks_flat // voc_size  # which live hypothesis each candidate extends
                word_indices = ranks_flat % voc_size    # which symbol it appends, in [0, voc_size)
                costs = cand_flat[ranks_flat]
                new_hyp_samples = []
                new_hyp_scores = np.zeros(k-dead_k).astype('float32')
                new_hyp_states = []
                new_hyp_alpha_past = []

                for idx, [ti, wi] in enumerate(zip(trans_indices, word_indices)):
                    new_hyp_samples.append(hyp_samples[ti]+[wi])
                    new_hyp_scores[idx] = copy.copy(costs[idx])
                    new_hyp_states.append(copy.copy(next_state[ti]))
                    new_hyp_alpha_past.append(copy.copy(next_alpha_past[ti]))

                # split candidates into finished (<eol>) and still-live ones
                new_live_k = 0
                hyp_samples = []
                hyp_scores = []
                hyp_states = []
                hyp_alpha_past = []

                for idx in range(len(new_hyp_samples)):
                    if new_hyp_samples[idx][-1] == 0: # <eol>
                        sample.append(new_hyp_samples[idx])
                        sample_score.append(new_hyp_scores[idx])
                        dead_k += 1
                    else:
                        new_live_k += 1
                        hyp_samples.append(new_hyp_samples[idx])
                        hyp_scores.append(new_hyp_scores[idx])
                        hyp_states.append(new_hyp_states[idx])
                        hyp_alpha_past.append(new_hyp_alpha_past[idx])
                hyp_scores = np.array(hyp_scores)
                live_k = new_live_k

                if new_live_k < 1:
                    break
                if dead_k >= k:
                    break

                next_w = np.array([w1[-1] for w1 in hyp_samples])
                next_state = np.array(hyp_states)
                next_alpha_past = np.array(hyp_alpha_past)

        if not stochastic:
            # dump every hypothesis that was still live when decoding stopped
            if live_k > 0:
                for idx in range(live_k):
                    sample.append(hyp_samples[idx])
                    sample_score.append(hyp_scores[idx])

        return sample, sample_score,next_alpha_past

def cmp_result(label,rec):
    """Edit distance between a reference and a recognised sequence.

    Uses unit costs for substitution, insertion and deletion and returns
    ``(distance, len(label))``.
    """
    n_label, n_rec = len(label), len(rec)
    table = np.zeros((n_label + 1, n_rec + 1), dtype='int32')
    # Turning a prefix into the empty sequence costs its length.
    table[0, :] = np.arange(n_rec + 1)
    table[:, 0] = np.arange(n_label + 1)

    for row, ref_item in enumerate(label, start=1):
        for col, hyp_item in enumerate(rec, start=1):
            substitute = table[row - 1, col - 1] + (ref_item != hyp_item)
            insert = table[row, col - 1] + 1
            delete = table[row - 1, col] + 1
            table[row, col] = min(substitute, insert, delete)

    return table[n_label, n_rec], n_label
def convert(list):
    """Concatenate the ``str()`` form of every item, with no separator."""
    return "".join(str(item) for item in list)
def process_chr_error(recfile, labelfile, resultfile):
    """Compute the character error rate and the line recognition rate.

    Both input files hold one sample per line: a key followed by
    whitespace-separated tokens. Blank lines are skipped. Every key of
    ``recfile`` is compared against the entry with the same key in
    ``labelfile``.

    Args:
        recfile: path of the recognition results.
        labelfile: path of the ground-truth labels.
        resultfile: unused here; accepted so the signature matches the
            other ``process_*`` functions.

    Returns:
        ``(chr_error, line_rec)``: total edit distance divided by the total
        number of label tokens, and the fraction of lines recognised without
        any error. Both are 0.0 when ``recfile`` has no samples;
        ``chr_error`` is ``inf`` when errors exist but all labels are empty.

    Raises:
        KeyError: if a key of ``recfile`` is missing from ``labelfile``.
    """
    def _read_sequences(path):
        # key -> list of tokens; a repeated key keeps its last occurrence
        table = {}
        with open(path) as handle:
            for line in handle:
                fields = line.split()
                if not fields:      # blank line: nothing to index
                    continue
                table[fields[0]] = fields[1:]
        return table

    rec_mat = _read_sequences(recfile)
    label_mat = _read_sequences(labelfile)

    total_dist = 0
    total_label = 0
    total_line = 0
    total_line_rec = 0
    for key_rec, rec in rec_mat.items():
        label = label_mat[key_rec]
        dist, llen = cmp_result(label, rec)
        total_dist += int(dist)     # plain int: no fixed-width accumulation
        total_label += llen
        total_line += 1
        if dist == 0:
            total_line_rec += 1

    if total_label:
        chr_error = float(total_dist) / total_label
    else:
        # no reference tokens at all: perfect if nothing was hypothesised
        chr_error = 0.0 if total_dist == 0 else float('inf')
    line_rec = float(total_line_rec) / total_line if total_line else 0.0
    return chr_error, line_rec
def process_wer_error(recfile, labelfile, resultfile, word_sep="4"):
    """Compute the word error rate between recognition results and labels.

    Both input files hold one sample per line: a key followed by
    whitespace-separated tokens. Blank lines are skipped. Each token
    sequence is cut into words at every token equal to ``word_sep``; the
    edit distance is then taken over those words.

    Words are split on whole tokens. Concatenating the tokens and splitting
    the resulting string on the character "4" would also cut tokens such as
    "14" or "40" apart and would merge neighbouring tokens ("1", "2" vs.
    "12").

    Args:
        recfile: path of the recognition results.
        labelfile: path of the ground-truth labels.
        resultfile: unused here; accepted so the signature matches the
            other ``process_*`` functions.
        word_sep: token that separates words (default "4").

    Returns:
        Total word-level edit distance divided by the total number of label
        words; 0.0 when ``recfile`` has no samples.

    Raises:
        KeyError: if a key of ``recfile`` is missing from ``labelfile``.
    """
    def _split_words(tokens):
        # Mirrors str.split semantics: leading, trailing and repeated
        # separators produce empty words, and no tokens give one empty word.
        words = [[]]
        for token in tokens:
            if token == word_sep:
                words.append([])
            else:
                words[-1].append(token)
        return [tuple(word) for word in words]

    def _read_words(path):
        # key -> list of words; a repeated key keeps its last occurrence
        table = {}
        with open(path) as handle:
            for line in handle:
                fields = line.split()
                if not fields:      # blank line: nothing to index
                    continue
                table[fields[0]] = _split_words(fields[1:])
        return table

    rec_mat = _read_words(recfile)
    label_mat = _read_words(labelfile)

    total_dist = 0
    total_label = 0
    for key_rec, rec in rec_mat.items():
        label = label_mat[key_rec]
        dist, llen = cmp_result(label, rec)
        total_dist += int(dist)
        total_label += llen

    if not total_label:
        return 0.0
    return float(total_dist) / total_label
def process(recfile, labelfile, resultfile):
    """Evaluate recognition output and write CER, WER and LER to a file.

    Args:
        recfile: path of the recognition results.
        labelfile: path of the ground-truth labels.
        resultfile: path of the report to write; it is overwritten.
    """
    cer, ler = process_chr_error(recfile, labelfile, resultfile)
    wer = process_wer_error(recfile, labelfile, resultfile)
    # The context manager closes the report even if a write fails.
    with open(resultfile, 'w') as f_result:
        f_result.write('CER {}\n'.format(cer))
        f_result.write('WER {}\n'.format(wer))
        f_result.write('LER {}\n'.format(ler))

def dataIterator(feature_file,label_file,batch_size,batch_Imagesize,maxlen,maxImagesize):
    """Load pickled images and labels and group them into batches.

    Samples are visited in order of increasing image area (rows x columns).
    A batch is closed when adding the next sample would make
    ``largest_area * sample_count`` exceed ``batch_Imagesize`` or when it
    already holds ``batch_size`` samples. Samples whose label is longer than
    ``maxlen`` or whose area exceeds ``maxImagesize`` are reported and
    skipped. Empty batches are never emitted.

    Args:
        feature_file: pickle file holding a dict uid -> image array.
        label_file: pickle file holding a dict uid -> label sequence.
        batch_size: maximum number of samples per batch.
        batch_Imagesize: budget for ``largest_area * sample_count``.
        maxlen: maximum accepted label length.
        maxImagesize: maximum accepted image area.

    Returns:
        ``(batches, uidList)``: a list of ``(feature_batch, label_batch)``
        pairs and the uids of all accepted samples in visiting order.
    """
    # SECURITY NOTE: pickle.load can execute arbitrary code; only use it on
    # files from a trusted source.
    with open(feature_file, 'rb') as fp:
        features = pkl.load(fp)
    with open(label_file, 'rb') as fp2:
        labels = pkl.load(fp2)

    imageSize = {}
    for uid, fea in features.items():
        imageSize[uid] = fea.shape[0] * fea.shape[1]

    # (uid, area) pairs sorted by ascending image area
    imageSize = sorted(imageSize.items(), key=lambda d: d[1])

    feature_batch = []
    label_batch = []
    feature_total = []
    label_total = []
    uidList = []

    biggest_image_size = 0
    i = 0     # number of samples in the batch being filled
    for uid, size in imageSize:
        if size > biggest_image_size:
            biggest_image_size = size
        fea = features[uid]
        lab = labels[uid]
        # padded area of the batch if this sample were added to it
        batch_image_size = biggest_image_size * (i + 1)
        if len(lab) > maxlen:
            print('sentence', uid, 'length bigger than', maxlen, 'ignore')
        elif size > maxImagesize:
            print('image', uid, 'size bigger than', maxImagesize, 'ignore')
        else:
            uidList.append(uid)
            batch_is_full = batch_image_size > batch_Imagesize or i == batch_size
            # Only close a batch that actually holds samples; a single
            # over-budget sample simply gets a batch of its own.
            if batch_is_full and feature_batch:
                feature_total.append(feature_batch)
                label_total.append(label_batch)

                biggest_image_size = size
                feature_batch = [fea]
                label_batch = [lab]
                i = 1
            else:
                feature_batch.append(fea)
                label_batch.append(lab)
                i += 1

    # last batch (skipped when every sample was filtered out)
    if feature_batch:
        feature_total.append(feature_batch)
        label_total.append(label_batch)

    print('total ',len(feature_total), 'batch data loaded')

    return list(zip(feature_total,label_total)),uidList

def prepare_data(images_x, seqs_y, n_words_src=30000,
                 n_words=30000):
    """Pad a batch of images and label sequences into dense arrays.

    Images are placed in the top-left corner of a zero canvas sized to the
    tallest and widest image and divided by 255; images with more than two
    axes are first converted with ``cv2.COLOR_BGR2GRAY``. Labels are stored
    column-wise with one extra row left at 0, which serves as <eol>.
    ``n_words_src`` and ``n_words`` are accepted but not used.

    Returns:
        ``(x, x_mask, y, y_mask)`` where ``x`` and ``x_mask`` are float32 of
        shape (n_samples, max_height, max_width), ``y`` is int64 and
        ``y_mask`` float32, both of shape (max_label_len + 1, n_samples).
        The masks are 1 where real data was written.
    """
    heights = [img.shape[0] for img in images_x]
    widths = [img.shape[1] for img in images_x]
    label_lens = [len(seq) for seq in seqs_y]

    n_samples = len(heights)
    canvas_shape = (n_samples, np.max(heights), np.max(widths))
    steps = np.max(label_lens) + 1      # one spare row for <eol>

    x = np.zeros(canvas_shape).astype('float32')
    y = np.zeros((steps, n_samples)).astype('int64')    # the <eol> must be 0 in the dict !!!
    x_mask = np.zeros(canvas_shape).astype('float32')
    y_mask = np.zeros((steps, n_samples)).astype('float32')

    for idx, (image, labels) in enumerate(zip(images_x, seqs_y)):
        if len(image.shape) > 2:
            image = cv2.cvtColor(image.astype('float32'), cv2.COLOR_BGR2GRAY)
        rows, cols, n_labels = heights[idx], widths[idx], label_lens[idx]
        x[idx, :rows, :cols] = image / 255.
        x_mask[idx, :rows, :cols] = 1.
        y[:n_labels, idx] = labels
        y_mask[:n_labels, idx] = 1.
    return x, x_mask, y, y_mask

def _read_error_rates(wer_path):
    """Read CER, WER and LER from a ``.wer`` summary file as percentages.

    The first three lines of the file are expected to look like
    ``CER <value>``, ``WER <value>`` and ``LER <value>`` (each terminated by a
    newline); every value is multiplied by 100.

    Args:
        wer_path: path of the summary file written by ``process``.

    Returns:
        Tuple ``(cer, wer, ler)`` of floats.

    Raises:
        IndexError: if the file has fewer than three lines.
        ValueError: if a line does not match the expected pattern or its
            value is not a float.
    """
    with open(wer_path) as fpp:
        stuff = fpp.readlines()

    rates = []
    for tag, line in (('CER', stuff[0]), ('WER', stuff[1]), ('LER', stuff[2])):
        m = re.search(tag + ' (.*)\n', line)
        if m is None:
            raise ValueError('cannot find %s value in %r' % (tag, line))
        rates.append(100. * float(m.group(1)))
    return tuple(rates)


def main():
    """Build the WAP graph, restore a checkpoint, then train or evaluate.

    Steps:
      1. Load the train/valid/test batches with ``dataIterator``.
      2. Build the encoder (``Watcher_train``), ``Attender``, ``Parser`` and
         ``WAP`` graph, the regularised cost and the Adadelta training op.
      3. Open a session, initialise variables and restore the hard-coded
         checkpoint.
      4. If the module constant ``MODE`` is ``'train'``: run up to 49 epochs;
         after each ``len(train)`` updates the validation set is decoded and
         scored, and the validation CER drives learning-rate decay and early
         stopping.
         If ``MODE`` is ``'test'``: decode the test set and report
         CER / WER / LER.

    All input and output paths are hard-coded under the Google Drive mount.
    The log file is opened in ``'w'`` mode, so each run overwrites it.
    """
    # These placeholders are published as module globals (code outside this
    # function may read them; that code is not visible from here).
    global anno, infer_y, h_pre, alpha_past, if_trainning

    #### encoder setup parameters ####
    dense_blocks = 3
    levels_count = 16
    growth = 24
    #### decoder setup parameters ####
    hidden_dim = 256
    word_dim = 256
    dim_attend = 512

    #### experiment locations ####
    exp_dir = '/gdrive/My Drive/URDU_V2_Exp/wap_exp'
    result_dir = exp_dir + '/result'
    model_dir = exp_dir + '/model2'

    #### Loading Data ####
    train, train_uid_list = dataIterator(
        exp_dir + '/data/trainImages_v2.pkl',
        exp_dir + '/data/trainLabels_v2.pkl',
        batch_size=BATCHSIZE, batch_Imagesize=500000, maxlen=130, maxImagesize=500000)
    valid, valid_uid_list = dataIterator(
        exp_dir + '/testdata/validImages_v2.pkl',
        exp_dir + '/testdata/validLabels_v2.pkl',
        batch_size=BATCHSIZE, batch_Imagesize=500000, maxlen=130, maxImagesize=500000)
    test, test_uid_list = dataIterator(
        exp_dir + '/testdata/testImages_v2.pkl',
        exp_dir + '/testdata/testLabels_v2.pkl',
        batch_size=1, batch_Imagesize=500000, maxlen=130, maxImagesize=500000)

    #### Graph Creation ####
    # NOTE: the construction order below is kept unchanged on purpose, so the
    # automatically generated variable names keep matching the checkpoint.
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, None, None])
        y = tf.placeholder(tf.int32, shape=[None, None])

        x_mask = tf.placeholder(tf.float32, shape=[None, None, None])
        y_mask = tf.placeholder(tf.float32, shape=[None, None])

        lr = tf.placeholder(tf.float32, shape=())

        if_trainning = tf.placeholder(tf.bool, shape=())
        alpha_reg = tf.placeholder(tf.float32, shape=())

        watcher_train = Watcher_train(blocks=dense_blocks, level=levels_count,
                                      growth_rate=growth, training=if_trainning)
        annotation, anno_mask = watcher_train.dense_net(x, x_mask)
        anno_shape = annotation.shape.as_list()

        # Placeholders used for step-by-step decoding (validation / test).
        anno = tf.placeholder(tf.float32, shape=[None, anno_shape[1], anno_shape[2], anno_shape[3]])
        infer_y = tf.placeholder(tf.int64, shape=(None,))
        h_pre = tf.placeholder(tf.float32, shape=[None, hidden_dim])
        alpha_past = tf.placeholder(tf.float32, shape=[None, anno_shape[1], anno_shape[2]])

        attender = Attender(anno_shape[3], hidden_dim, dim_attend)
        parser = Parser(hidden_dim, word_dim, attender, anno_shape[3])
        # NOTE(review): `num_classes` is not defined inside main(); the top of
        # the file defines NUM_CLASSES. Confirm a lowercase global exists,
        # otherwise this line raises NameError.
        wap = WAP(watcher_train, attender, parser, hidden_dim, word_dim,
                  anno_shape[3], num_classes, if_trainning)

        # Initial decoder state computed from the mean annotation vector.
        hidden_state_0 = tf.tanh(
            tf.tensordot(tf.reduce_mean(anno, axis=[1, 2]), wap.Wa2h, axes=1) + wap.ba2h)  # [batch, hidden_dim]

        cost = wap.get_cost(annotation, y, anno_mask, y_mask, alpha_reg)

        #### L2 penalty on every trainable variable whose name does not start with 'conv2d' ####
        for vv in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            if not vv.name.startswith('conv2d'):
                cost += 1e-4 * tf.reduce_sum(tf.pow(vv, 2))

        p, w, h, alpha = wap.get_word(infer_y, h_pre, alpha_past, anno)
        optimizer = tf.train.AdadeltaOptimizer(learning_rate=lr)

        #### gradients clipping ####
        # The ops in UPDATE_OPS must run together with every training step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars), 100)
            trainer = optimizer.apply_gradients(zip(grads, tvars))

        #### training bookkeeping ####
        uidx = 0
        cost_s = 0
        dispFreq = 200
        sampleFreq = len(train)
        validFreq = len(train)
        history_errs = []
        estop = False
        halfLrFlag = 0
        patience = 15
        lrate = 1

        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
        init = tf.global_variables_initializer()

        saver = tf.train.Saver(max_to_keep=None)
        # Only used by the (currently disabled) checkpoint saving below.
        checkpoint_path = model_dir + '/cp-{ckpt:04d}-{acc:04f}.ckpt'

        # Both the log file and the session are released even when the run
        # aborts (exception or sys.exit).
        with open(result_dir + '/log.txt', 'w') as log, tf.Session(config=config) as sess:

            def log_line(text):
                """Append one line to the log file and flush it immediately."""
                log.write(text + '\n')
                log.flush()

            def decode_to_files(dataset, uid_list, predicted_path, target_path):
                """Beam-search decode every image of ``dataset``.

                One line per image is written to each file: the uid followed
                by 'im', then the predicted ids (cut at the first 0, the
                <eol> id) in ``predicted_path`` or the reference ids in
                ``target_path``.
                """
                count = 0
                with open(predicted_path, 'w') as fpp_pred, open(target_path, 'w') as fpp_target:
                    for images, labels in dataset:
                        for xx, yy in zip(images, labels):
                            if len(xx.shape) > 2:
                                xx = cv2.cvtColor(xx.astype('float32'), cv2.COLOR_BGR2GRAY)
                            # Scale to [0, 1] and add the batch axis.
                            xx_pad = np.asarray(xx / 255., dtype='float32')[None, :, :]
                            annot = sess.run(annotation, feed_dict={x: xx_pad, if_trainning: False})
                            h_state = sess.run(hidden_state_0, feed_dict={anno: annot})
                            sample, score, _ = wap.get_sample(p, w, h, alpha, annot, h_state,
                                                              10, 130, False, sess, training=False)
                            # Length-normalise the scores and keep the lowest one.
                            score = score / np.array([len(s) for s in sample])
                            ss = sample[score.argmin()]

                            fpp_pred.write(str(uid_list[count]) + 'im')
                            fpp_target.write(str(uid_list[count]) + 'im')
                            count += 1
                            if count % 200 == 0:
                                print('gen %d samples' % count)
                                log_line('gen %d samples' % count)
                            for token in ss:
                                if token == 0:  # <eol>
                                    break
                                fpp_pred.write(' ' + str(token))
                            for y1 in yy:
                                fpp_target.write(' ' + str(y1))
                            fpp_pred.write('\n')
                            fpp_target.write('\n')

            sess.run(init)
            # Alternative: tf.train.latest_checkpoint(<model directory>).
            saver.restore(sess, model_dir + '/cp-0045-25.278371.ckpt')

            ### Training mode ###
            if MODE == 'train':
                valid_pred_path = result_dir + '/valid_predicted.txt'
                valid_target_path = result_dir + '/valid_target.txt'
                valid_wer_path = result_dir + '/valid-wer.wer'
                bad_counter = 0

                for epoch in range(1, 50):
                    random.shuffle(train)
                    start_time = time.time()

                    for batch_x, batch_y in train:
                        batch_x, batch_x_m, batch_y, batch_y_m = prepare_data(batch_x, batch_y)
                        uidx += 1
                        cost_i, _ = sess.run(
                            [cost, trainer],
                            feed_dict={x: batch_x, y: batch_y, x_mask: batch_x_m, y_mask: batch_y_m,
                                       if_trainning: True, lr: lrate, alpha_reg: 1})
                        cost_s += cost_i
                        if np.isnan(cost_i) or np.isinf(cost_i):
                            print('invalid cost value detected')
                            sys.exit(0)

                        if uidx % dispFreq == 0:
                            cost_s /= dispFreq
                            print('Epoch ', epoch, 'Update ', uidx, 'Cost ', cost_s, 'Lr ', lrate)
                            log_line('Epoch ' + str(epoch) + ' Update ' + str(uidx)
                                     + ' Cost ' + str(cost_s) + ' Lr ' + str(lrate))
                            cost_s = 0

                        if uidx % sampleFreq == 0:
                            decode_to_files(valid, valid_uid_list, valid_pred_path, valid_target_path)
                            print('valid set decode done')
                            log_line('valid set decode done')

                        if uidx % validFreq == 0:
                            probs = []
                            for valid_x, valid_y in valid:
                                valid_x, valid_x_m, valid_y, valid_y_m = prepare_data(valid_x, valid_y)
                                probs.append(sess.run(
                                    cost,
                                    feed_dict={x: valid_x, y: valid_y, x_mask: valid_x_m,
                                               y_mask: valid_y_m, if_trainning: False, alpha_reg: 1}))
                            valid_err_cost = np.array(probs).mean()

                            process(valid_pred_path, valid_target_path, valid_wer_path)
                            valid_cer, valid_wer, valid_ler = _read_error_rates(valid_wer_path)

                            valid_err = valid_cer
                            history_errs.append(valid_err)
                            best_err = np.array(history_errs).min()
                            if valid_err <= best_err:
                                bad_counter = 0
                            elif valid_err > best_err:
                                bad_counter += 1
                                if bad_counter > patience:
                                    if halfLrFlag == 2:
                                        print('Early Stop!')
                                        log_line('Early Stop!')
                                        estop = True
                                        break
                                    print('Lr decay and retrain!')
                                    log_line('Lr decay and retrain!')
                                    bad_counter = 0
                                    lrate = lrate / 10
                                    halfLrFlag += 1

                            summary = 'Valid CER: %.2f%%,Valid WER: %.2f%%, ExpRate: %.2f%%, Cost: %f' % (
                                valid_cer, valid_wer, valid_ler, valid_err_cost)
                            print(summary)
                            log_line(summary)

                            duration = time.time() - start_time
                            print('1 itr completion time: %.2f%%' % (duration) + '\n')
                            # Checkpoint saving is disabled; to re-enable:
                            # save_path = saver.save(sess, checkpoint_path.format(ckpt=epoch, acc=valid_cer))
                    if estop:
                        break

            ### Testing mode ###
            if MODE == 'test':
                test_pred_path = result_dir + '/test_predicted.txt'
                test_target_path = result_dir + '/test_target.txt'
                test_wer_path = result_dir + '/test-wer.wer'

                # The uids must come from the test set, not the validation set.
                decode_to_files(test, test_uid_list, test_pred_path, test_target_path)
                print('test set decode done')
                log_line('test set decode done')

                process(test_pred_path, test_target_path, test_wer_path)
                test_cer, test_wer, test_ler = _read_error_rates(test_wer_path)

                summary = 'test CER: %.2f%%,test WER: %.2f%%, ExpRate: %.2f%%' % (
                    test_cer, test_wer, test_ler)
                print(summary)
                log_line(summary)
			
# Run main() only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
	main()
