Skip to content

Commit

Permalink
Implementation of resizing codec
Browse files Browse the repository at this point in the history
  • Loading branch information
ChWick committed Dec 15, 2017
1 parent 3015bf4 commit f33831a
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 8 deletions.
65 changes: 63 additions & 2 deletions ocrolib/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from __future__ import print_function

import common as ocrolib
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, dot,
from numpy import (amax, amin, argmax, arange, array, clip, concatenate, delete, dot,
exp, isnan, log, maximum, mean, nan, ones, outer, roll, sum,
tanh, tile, vstack, zeros)
from pylab import (clf, cm, figure, ginput, imshow, newaxis, rand, subplot,
Expand Down Expand Up @@ -288,6 +288,19 @@ def __init__(self,Nh,No,initial_range=initial_range,rand=rand):
self.No = No
self.W2 = randu(No,Nh+1)*initial_range
self.DW2 = zeros((No,Nh+1))
self.initial_range = initial_range
def resizeOutput(self,No, deleted_positions):
"""resize all matrices to the new codec created by a given charset"""
# delete rows for chars that are not necessary
W2_temp = delete(self.W2, deleted_positions, axis=0)
# enlarge output and weights for extra chars
W2 = randu(No, self.Nh + 1) * initial_range
# use the trained weights (if --load was used)
W2[: len(W2_temp)] = W2_temp
self.W2 = W2
self.DW2 = zeros((No, self.Nh + 1))
self.No = No
self.deltas = None
def ninputs(self):
return self.Nh
def noutputs(self):
Expand Down Expand Up @@ -578,6 +591,8 @@ def backward(self,deltas):
self.DWGI,self.DWGF,self.DWGO,self.DWCI,
self.DWIP,self.DWFP,self.DWOP)
return [s[1:1+ni] for s in self.sourceerr[:n]]
def resizeOutput(self, No, deleted_positions):
pass

################################################################
# combination classifiers
Expand Down Expand Up @@ -621,6 +636,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Stacked%d/%s"%(i,n)
def resizeOutput(self, nout, deleted_positions):
self.nets[-1].resizeOutput(nout, deleted_positions)
self.deltas = None

class Reversed(Network):
"""Run a network on the time-reversed input."""
Expand All @@ -645,6 +663,8 @@ def states(self):
def weights(self):
for w,dw,n in self.net.weights():
yield w,dw,"Reversed/%s"%n
def resizeOutput(self, no, deleted_positions):
self.net.resizeOutput(no, deleted_positions)

class Parallel(Network):
"""Run multiple networks in parallel on the same input."""
Expand Down Expand Up @@ -679,6 +699,9 @@ def weights(self):
for i,net in enumerate(self.nets):
for w,dw,n in net.weights():
yield w,dw,"Parallel%d/%s"%(i,n)
def resizeOutput(self, no, deleted_positions):
for net in self.nets:
net.resizeOutput(no, deleted_positions)

def MLP1(Ni,Ns,No):
"""An MLP implementation by stacking two `Logreg` networks on top
Expand Down Expand Up @@ -839,7 +862,7 @@ def ctc_align_targets(outputs,targets,threshold=100.0,verbose=0,debug=0,lo=1e-5)
return aligned

def normalize_nfkc(s):
return unicodedata.normalize('NFKC',s)
return unicodedata.normalize('NFC',s)

def add_training_info(network):
return network
Expand Down Expand Up @@ -933,6 +956,17 @@ def predictString(self,xs):
"Predict output as a string. This uses codec and normalizer."
cs = self.predictSequence(xs)
return self.l2s(cs)
def resizeCodec(self, codec):
"""create a codec that exactly fits to ground truth/given codec as parameter"""
print("# creating a codec thas fits to the given charset")
# add all unknown and new chars to the codec
self.codec.extend(codec)
# search for chars that should not be in the codec anymore
deleted_positions = self.codec.shrink(codec)
# let the output fit to the new defined codec
self.lstm.resizeOutput(self.codec.size(), deleted_positions)
self.No = self.codec.size()
return self.codec

class Codec:
"""Translate between integer codes and characters."""
Expand All @@ -957,6 +991,33 @@ def decode(self,l):
"Decode a code sequence into a string."
s = [self.code2char.get(c,"~") for c in l]
return s
def extend(self, codec):
charset = self.code2char.values()
size = self.size()
counter = 0
for c in codec.code2char.values():
if not c in charset: # append chars that doesn't appear in the codec
self.code2char[size] = c
self.char2code[c] = size
size += 1
counter += 1
print("#", counter, " extra chars added")
def shrink(self, codec):
deleted_positions = []
positions = []
for number, char in self.code2char.iteritems():
if not char in codec.char2code and char != "~":
deleted_positions.append(number)
else:
positions.append(number)
charset = [self.code2char[c] for c in sorted(positions)]
self.code2char = {}
self.char2code = {}
for code, char in enumerate(charset):
self.code2char[code] = char
self.char2code[char] = code
print("#", len(deleted_positions), " unnecessary chars deleted")
return deleted_positions

ascii_labels = [""," ","~"] + [unichr(x) for x in range(33,126)]

Expand Down
21 changes: 15 additions & 6 deletions ocropus-rtrain
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def save_lstm(fname,network):
for x in network.walk(): x.postLoad()


def load_lstm(fname):
def load_lstm(fname, codec):
if args.clstm:
network = lstm.SeqRecognizer(args.height,args.hiddensize,
codec=codec,
Expand All @@ -178,17 +178,27 @@ def load_lstm(fname):
mylstm.init(network.No,args.hiddensize,network.Ni)
mylstm.load(fname)
network.lstm = clstm.CNetwork(mylstm)
return network
else:
network = ocrolib.load_object(last_save)
network.upgrade()
for x in network.walk(): x.postLoad()
return network

# if a model was loaded we must change the local codec in any case
# either resize the codec of the network if a codec is given
# or use the loaded codec directly
if args.codec != []:
# resize the network codec (including the network weights)
codec = network.resizeCodec(codec)
else:
# the local codec is simply the local codec
codec = network.codec

return network, codec

if args.load:
print("# loading", args.load)
last_save = args.load
network = load_lstm(args.load)
network, codec = load_lstm(args.load, codec)
else:
last_save = None
network = lstm.SeqRecognizer(args.height,args.hiddensize,
Expand Down Expand Up @@ -296,8 +306,7 @@ for trial in range(start,args.ntrain):
except FloatingPointError as e:
print("# oops, got FloatingPointError", e)
traceback.print_exc()
network = load_lstm(last_save)
continue
network, codec = load_lstm(last_save, codec)
except lstm.RangeError as e:
continue
pred = "".join(codec.decode(pcs))
Expand Down

0 comments on commit f33831a

Please sign in to comment.