Skip to content

Commit

Permalink
Implementation of resizing codec
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed Feb 20, 2018
1 parent d3e5cc6 commit c952a85
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 6 deletions.
61 changes: 61 additions & 0 deletions ocrolib/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,19 @@ def __init__(self,Nh,No,initial_range=initial_range,rand=np.random.rand):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = np.zeros((No,Nh+1))
self.initial_range = initial_range
def resizeOutput(self, No, deleted_positions):
    """Resize the output weights so they match a changed codec.

    Rows of `W2` listed in `deleted_positions` are dropped, the surviving
    rows are copied to the top of a freshly initialised `(No, Nh+1)`
    matrix, and any additional rows keep their random initial values.
    `DW2` is reallocated with zeros, `No` is updated and `deltas` is reset.

    :param No: new number of outputs.
    :param deleted_positions: indices of the rows of `W2` to remove.
    :raises ValueError: if more rows survive the deletion than `No` allows.
    """
    # delete rows for chars that are not necessary
    kept = np.delete(self.W2, deleted_positions, axis=0)
    if len(kept) > No:
        raise ValueError("cannot resize to %d outputs: %d trained rows remain"
                         % (No, len(kept)))
    # Scale new weights with the range this layer was constructed with.
    # Instances lacking the attribute (presumably objects saved before it
    # was introduced) fall back to the module-level default.
    scale = getattr(self, "initial_range", None)
    if scale is None:
        scale = initial_range
    # enlarge output and weights for extra chars
    W2 = randu(No, self.Nh + 1) * scale
    # use the trained weights (if --load was used)
    W2[:len(kept)] = kept
    self.W2 = W2
    self.DW2 = np.zeros((No, self.Nh + 1))
    self.No = No
    self.deltas = None
def ninputs(self):
return self.Nh
def noutputs(self):
Expand Down Expand Up @@ -577,6 +590,8 @@ def backward(self,deltas):
self.DWGI,self.DWGF,self.DWGO,self.DWCI,
self.DWIP,self.DWFP,self.DWOP)
return [s[1:1+ni] for s in self.sourceerr[:n]]
def resizeOutput(self, No, deleted_positions):
    """Accept a resize request without changing anything.

    Both arguments are ignored and nothing on the instance is modified.
    """
    return None

################################################################
# combination classifiers
Expand Down Expand Up @@ -620,6 +635,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Stacked%d/%s"%(i,n)
def resizeOutput(self, nout, deleted_positions):
    """Resize the output of the stack by resizing its final network.

    Only the last entry of `self.nets` receives the request; afterwards
    `self.deltas` is reset to `None`.
    """
    final_net = self.nets[-1]
    final_net.resizeOutput(nout, deleted_positions)
    self.deltas = None

class Reversed(Network):
"""Run a network on the time-reversed input."""
Expand All @@ -644,6 +662,8 @@ def states(self):
def weights(self):
for w,dw,n in self.net.weights():
yield w,dw,"Reversed/%s"%n
def resizeOutput(self, no, deleted_positions):
    """Forward the resize request unchanged to the wrapped network."""
    wrapped = self.net
    wrapped.resizeOutput(no, deleted_positions)

class Parallel(Network):
"""Run multiple networks in parallel on the same input."""
Expand Down Expand Up @@ -678,6 +698,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Parallel%d/%s"%(i,n)
def resizeOutput(self, no, deleted_positions):
    """Pass the resize request, with the same arguments, to every network
    in `self.nets`, in list order."""
    for member in self.nets:
        member.resizeOutput(no, deleted_positions)

def MLP1(Ni,Ns,No):
"""An MLP implementation by stacking two `Logreg` networks on top
Expand Down Expand Up @@ -932,6 +955,17 @@ def predictString(self,xs):
"Predict output as a string. This uses codec and normalizer."
cs = self.predictSequence(xs)
return self.l2s(cs)
def resizeCodec(self, codec):
    """Adapt this recognizer's codec and network output to `codec`.

    Characters of `codec` that are not yet known are appended to
    `self.codec`, characters that `codec` does not contain are removed,
    and the network output is resized to the resulting codec size.

    :param codec: the codec the recognizer should fit afterwards.
    :returns: `self.codec` after it has been updated.
    """
    print("# creating a codec thas fits to the given charset")
    own = self.codec
    # grow first: append all chars that are new to the current codec
    own.extend(codec)
    # then prune: drop chars that should no longer be in the codec
    removed = own.shrink(codec)
    # make the network output match the newly defined codec
    self.lstm.resizeOutput(own.size(), removed)
    self.No = own.size()
    return own

class Codec:
"""Translate between integer codes and characters."""
Expand All @@ -956,6 +990,33 @@ def decode(self,l):
"Decode a code sequence into a string."
s = [self.code2char.get(c,"~") for c in l]
return s
def extend(self, codec):
    """Append every character of `codec` that this codec does not know yet.

    New characters receive consecutive codes starting at the current
    size, in the iteration order of `codec.code2char.values()`. Codes of
    characters that are already present are left untouched.

    :param codec: codec whose characters should be representable afterwards.
    """
    next_code = self.size()
    added = 0
    for c in codec.code2char.values():
        # char2code is written together with code2char below, so it gives
        # a constant-time membership test that also sees characters added
        # earlier in this loop (a repeated character is added only once).
        if c in self.char2code:
            continue
        self.code2char[next_code] = c
        self.char2code[c] = next_code
        next_code += 1
        added += 1
    print("#", added, " extra chars added")
def shrink(self, codec):
    """Remove every character that `codec` lacks and renumber the rest.

    The character "~" is always kept. Surviving characters keep their
    relative order (by old code) and are assigned the codes 0..n-1; both
    lookup tables are rebuilt accordingly.

    :param codec: codec whose `char2code` decides which characters stay.
    :returns: sorted list of the old codes that were removed.
    """
    deleted_positions = []
    kept = []
    # .items() rather than .iteritems(): works on Python 2 and Python 3
    for number, char in self.code2char.items():
        if char in codec.char2code or char == "~":
            kept.append(number)
        else:
            deleted_positions.append(number)
    kept.sort()
    # dict iteration order is not guaranteed; sort for a stable result
    deleted_positions.sort()
    charset = [self.code2char[number] for number in kept]
    self.code2char = dict(enumerate(charset))
    self.char2code = {char: code for code, char in enumerate(charset)}
    print("#", len(deleted_positions), " unnecessary chars deleted")
    return deleted_positions

# Default label set: the empty string, space and "~", followed by the
# characters with code points 33 through 125.
ascii_labels = [""," ","~"] + [unichr(x) for x in range(33,126)]

Expand Down
21 changes: 15 additions & 6 deletions ocropus-rtrain
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def save_lstm(fname,network):
for x in network.walk(): x.postLoad()


def load_lstm(fname):
def load_lstm(fname, codec):
if args.clstm:
network = lstm.SeqRecognizer(args.height,args.hiddensize,
codec=codec,
Expand All @@ -188,17 +188,27 @@ def load_lstm(fname):
mylstm.init(network.No,args.hiddensize,network.Ni)
mylstm.load(fname)
network.lstm = clstm.CNetwork(mylstm)
return network
else:
network = ocrolib.load_object(last_save)
network.upgrade()
for x in network.walk(): x.postLoad()
return network

# if a model was loaded we must change the local codec in any case
# either resize the codec of the network if a codec is given
# or use the loaded codec directly
if args.codec != []:
# resize the network codec (including the network weights)
codec = network.resizeCodec(codec)
else:
# no codec was given: use the loaded network's codec as the local codec
codec = network.codec

return network, codec

if args.load:
print("# loading", args.load)
last_save = args.load
network = load_lstm(args.load)
network, codec = load_lstm(args.load, codec)
else:
last_save = None
network = lstm.SeqRecognizer(args.height,args.hiddensize,
Expand Down Expand Up @@ -306,8 +316,7 @@ for trial in range(start,args.ntrain):
except FloatingPointError as e:
print("# oops, got FloatingPointError", e)
traceback.print_exc()
network = load_lstm(last_save)
continue
network, codec = load_lstm(last_save, codec)
except lstm.RangeError as e:
continue
pred = "".join(codec.decode(pcs))
Expand Down

0 comments on commit c952a85

Please sign in to comment.