Skip to content

Commit

Permalink
Implementation of resizing codec
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed Dec 13, 2017
1 parent 3015bf4 commit 64f82fa
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 3 deletions.
65 changes: 63 additions & 2 deletions ocrolib/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from __future__ import print_function

import common as ocrolib
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, dot,
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, delete, dot,
exp, isnan, log, maximum, mean, nan, ones, outer, roll, sum,
tanh, tile, vstack, zeros)
from pylab import (clf, cm, figure, ginput, imshow, newaxis, rand, subplot,
Expand Down Expand Up @@ -288,6 +288,19 @@ def __init__(self,Nh,No,initial_range=initial_range,rand=rand):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = zeros((No,Nh+1))
self.initial_range = initial_range
def resizeoutput(self,No, deleted_positions):
"""resize all matrices to the new codec created by a given charset"""
# delete rows for chars that are not necessary
W2_temp = delete(self.W2, deleted_positions, axis=0)
# enlarge output and weights for extra chars
W2 = randu(No, self.Nh + 1) * initial_range
# use the trained weights (if --load was used)
W2[: len(W2_temp)] = W2_temp
self.W2 = W2
self.DW2 = zeros((No, self.Nh + 1))
self.No = No
self.deltas = None
def ninputs(self):
return self.Nh
def noutputs(self):
Expand Down Expand Up @@ -578,6 +591,8 @@ def backward(self,deltas):
self.DWGI,self.DWGF,self.DWGO,self.DWCI,
self.DWIP,self.DWFP,self.DWOP)
return [s[1:1+ni] for s in self.sourceerr[:n]]
def resizeoutput(self, No, deleted_positions):
pass

################################################################
# combination classifiers
Expand Down Expand Up @@ -621,6 +636,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Stacked%d/%s"%(i,n)
def resizeoutput(self, nout, deleted_positions):
self.nets[-1].resizeoutput(nout, deleted_positions)
self.deltas = None

class Reversed(Network):
"""Run a network on the time-reversed input."""
Expand All @@ -645,6 +663,8 @@ def states(self):
def weights(self):
for w,dw,n in self.net.weights():
yield w,dw,"Reversed/%s"%n
def resizeoutput(self, no, deleted_positions):
self.net.resizeOutput(no, deleted_positions)

class Parallel(Network):
"""Run multiple networks in parallel on the same input."""
Expand Down Expand Up @@ -679,6 +699,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Parallel%d/%s"%(i,n)
def resizeoutput(self, no, deleted_positions):
for net in self.nets:
net.resizeoutput(no, deleted_positions)

def MLP1(Ni,Ns,No):
"""An MLP implementation by stacking two `Logreg` networks on top
Expand Down Expand Up @@ -839,7 +862,7 @@ def ctc_align_targets(outputs,targets,threshold=100.0,verbose=0,debug=0,lo=1e-5)
return aligned

def normalize_nfkc(s):
return unicodedata.normalize('NFKC',s)
return unicodedata.normalize('NFC',s)

def add_training_info(network):
return network
Expand Down Expand Up @@ -933,6 +956,17 @@ def predictString(self,xs):
"Predict output as a string. This uses codec and normalizer."
cs = self.predictSequence(xs)
return self.l2s(cs)
def extendCodec(self, codec):
"""create a codec that exactly fits to ground truth/given codec as parameter"""
print("# creating a codec thas fits to the given charset")
# add all unknown and new chars to the codec
self.codec.extend(codec)
# search for chars that should not be in the codec anymore
deleted_positions = self.codec.shrink(codec)
# let the output fit to the new defined codec
self.lstm.resizeoutput(self.codec.size(), deleted_positions)
self.No = self.codec.size()
return self.codec

class Codec:
"""Translate between integer codes and characters."""
Expand All @@ -957,6 +991,33 @@ def decode(self,l):
"Decode a code sequence into a string."
s = [self.code2char.get(c,"~") for c in l]
return s
def extend(self, codec):
charset = self.code2char.values()
size = self.size()
counter = 0
for c in codec.code2char.values():
if not c in charset: # append chars that doesn't appear in the codec
self.code2char[size] = c
self.char2code[c] = size
size += 1
counter += 1
print("#", counter, " extra chars added")
def shrink(self, codec):
deleted_positions = []
positions = []
for number, char in self.code2char.iteritems():
if not char in codec.char2code and char != "~":
deleted_positions.append(number)
else:
positions.append(number)
charset = [self.code2char[c] for c in sorted(positions)]
self.code2char = {}
self.char2code = {}
for code, char in enumerate(charset):
self.code2char[code] = char
self.char2code[char] = code
print("#", len(deleted_positions), " unnecessary chars deleted")
return deleted_positions

ascii_labels = [""," ","~"] + [unichr(x) for x in range(33,126)]

Expand Down
8 changes: 7 additions & 1 deletion ocropus-rtrain
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,11 @@ network.upgrade()
if network.last_trial%100==99: network.last_trial += 1
print("# last_trial", network.last_trial)

# required if model was loaded
if args.codec != []:
codec = network.extendCodec(codec)
else:
codec = network.codec

# set up the learning rate

Expand Down Expand Up @@ -296,7 +301,8 @@ for trial in range(start,args.ntrain):
except FloatingPointError as e:
print("# oops, got FloatingPointError", e)
traceback.print_exc()
network = load_lstm(last_save)
# TODO: adjust codec if an error occurred
# network = load_lstm(last_save)
continue
except lstm.RangeError as e:
continue
Expand Down

0 comments on commit 64f82fa

Please sign in to comment.