Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation of resizing codec #277

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
61 changes: 61 additions & 0 deletions ocrolib/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,19 @@ def __init__(self,Nh,No,initial_range=initial_range,rand=np.random.rand):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = np.zeros((No,Nh+1))
self.initial_range = initial_range
def resizeOutput(self,No, deleted_positions):
"""resize all matrices to the new codec created by a given charset"""
# delete rows for chars that are not necessary
W2_temp = np.delete(self.W2, deleted_positions, axis=0)
# enlarge output and weights for extra chars
W2 = randu(No, self.Nh + 1) * initial_range
# use the trained weights (if --load was used)
W2[: len(W2_temp)] = W2_temp
self.W2 = W2
self.DW2 = np.zeros((No, self.Nh + 1))
self.No = No
self.deltas = None
def ninputs(self):
return self.Nh
def noutputs(self):
Expand Down Expand Up @@ -577,6 +590,8 @@ def backward(self,deltas):
self.DWGI,self.DWGF,self.DWGO,self.DWCI,
self.DWIP,self.DWFP,self.DWOP)
return [s[1:1+ni] for s in self.sourceerr[:n]]
def resizeOutput(self, No, deleted_positions):
pass

################################################################
# combination classifiers
Expand Down Expand Up @@ -620,6 +635,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Stacked%d/%s"%(i,n)
def resizeOutput(self, nout, deleted_positions):
self.nets[-1].resizeOutput(nout, deleted_positions)
self.deltas = None

class Reversed(Network):
"""Run a network on the time-reversed input."""
Expand All @@ -644,6 +662,8 @@ def states(self):
def weights(self):
for w,dw,n in self.net.weights():
yield w,dw,"Reversed/%s"%n
def resizeOutput(self, no, deleted_positions):
self.net.resizeOutput(no, deleted_positions)

class Parallel(Network):
"""Run multiple networks in parallel on the same input."""
Expand Down Expand Up @@ -678,6 +698,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Parallel%d/%s"%(i,n)
def resizeOutput(self, no, deleted_positions):
for net in self.nets:
net.resizeOutput(no, deleted_positions)

def MLP1(Ni,Ns,No):
"""An MLP implementation by stacking two `Logreg` networks on top
Expand Down Expand Up @@ -932,6 +955,17 @@ def predictString(self,xs):
"Predict output as a string. This uses codec and normalizer."
cs = self.predictSequence(xs)
return self.l2s(cs)
def resizeCodec(self, codec):
"""create a codec that exactly fits to ground truth/given codec as parameter"""
print("# creating a codec thas fits to the given charset")
# add all unknown and new chars to the codec
self.codec.extend(codec)
# search for chars that should not be in the codec anymore
deleted_positions = self.codec.shrink(codec)
# let the output fit to the new defined codec
self.lstm.resizeOutput(self.codec.size(), deleted_positions)
self.No = self.codec.size()
return self.codec

class Codec:
"""Translate between integer codes and characters."""
Expand All @@ -956,6 +990,33 @@ def decode(self,l):
"Decode a code sequence into a string."
s = [self.code2char.get(c,"~") for c in l]
return s
def extend(self, codec):
charset = self.code2char.values()
size = self.size()
counter = 0
for c in codec.code2char.values():
if not c in charset: # append chars that doesn't appear in the codec
self.code2char[size] = c
self.char2code[c] = size
size += 1
counter += 1
print("#", counter, " extra chars added")
def shrink(self, codec):
deleted_positions = []
positions = []
for number, char in self.code2char.iteritems():
if not char in codec.char2code and char != "~":
deleted_positions.append(number)
else:
positions.append(number)
charset = [self.code2char[c] for c in sorted(positions)]
self.code2char = {}
self.char2code = {}
for code, char in enumerate(charset):
self.code2char[code] = char
self.char2code[char] = code
print("#", len(deleted_positions), " unnecessary chars deleted")
return deleted_positions

ascii_labels = [""," ","~"] + [unichr(x) for x in range(33,126)]

Expand Down
20 changes: 15 additions & 5 deletions ocropus-rtrain
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def save_lstm(fname,network):
for x in network.walk(): x.postLoad()


def load_lstm(fname):
def load_lstm(fname, codec):
if args.clstm:
network = lstm.SeqRecognizer(args.height,args.hiddensize,
codec=codec,
Expand All @@ -188,17 +188,27 @@ def load_lstm(fname):
mylstm.init(network.No,args.hiddensize,network.Ni)
mylstm.load(fname)
network.lstm = clstm.CNetwork(mylstm)
return network
else:
network = ocrolib.load_object(last_save)
network.upgrade()
for x in network.walk(): x.postLoad()
return network

# if a model was loaded we must change the local codec in any case
# either resize the codec of the network if a codec is given
# or use the loaded codec directly
if args.codec != []:
# resize the network codec (including the network weights)
codec = network.resizeCodec(codec)
else:
# the local codec is simply the local codec
codec = network.codec

return network, codec

if args.load:
print("# loading", args.load)
last_save = args.load
network = load_lstm(args.load)
network, codec = load_lstm(args.load, codec)
else:
last_save = None
network = lstm.SeqRecognizer(args.height,args.hiddensize,
Expand Down Expand Up @@ -306,7 +316,7 @@ for trial in range(start,args.ntrain):
except FloatingPointError as e:
print("# oops, got FloatingPointError", e)
traceback.print_exc()
network = load_lstm(last_save)
network, codec = load_lstm(last_save, codec)
continue
except lstm.RangeError as e:
continue
Expand Down