-
Notifications
You must be signed in to change notification settings - Fork 588
/
ocropus-ctrain
executable file
·339 lines (277 loc) · 12 KB
/
ocropus-ctrain
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
#!/usr/bin/python
import sys,os,re,glob,math,glob,signal,traceback
import matplotlib
if "DISPLAY" not in os.environ: matplotlib.use("AGG")
else: matplotlib.use("GTK")
import random as pyrandom
from optparse import OptionParser
from pylab import *
from scipy.ndimage import interpolation,filters
import ocrolib
from ocrolib import docproc,Record,mlp,dbhelper
from scipy import stats
# make Ctrl-C terminate the program with exit status 1
signal.signal(signal.SIGINT,lambda *args:sys.exit(1))

# command line definition; the usage text is shown by --help and on errors
parser = OptionParser(usage="""
%prog [options] -o output.cmodel input.db ...
Trains models based on a cluster database.
For faster speed and better memory usage, use the "-b" option, which buffers
samples in a 1bpp buffer (only binary input patterns); however, this only works
with binary inputs and feature extractors that generate (approximately) binary
data. For example, it works for binary character images and ScaledExtractor, but not
if you use grayscale character images or StandardExtractor.
If you have lots of training data, try "-E ScaledFeatureExtractor -b", for handwriting
recognition, "-E StandardExtractor" is a better choice.
You can choose different kinds of feature extractors with the -E flag.
Some possible values are: ScaledExtractor (raw grayscale image rescaled to a target size)
and BiggestCcExtractor (biggest connected component only, otherwise treated like scaledfe).
You can find additional components by running "ocropus components" and looking for
implementors of IExtractor.
Common parameters for the model are:
-m '=ocrolib.mlp.AutoMlpModel()'
""")

# general options
parser.add_option("-o","--output",help="output model name",default=None)
parser.add_option("-m","--model",help="IModel name",default="ocrolib.mlp.AutoMlpModel()")
parser.add_option("-K","--nrounds",help="number of training rounds",type="int",default=48)
parser.add_option("-b","--bits",help="buffer training data with 1 bpp",action="store_true")
parser.add_option("-t","--table",help="database table to use for training",default="chars")
parser.add_option("-u","--unlabeled",help="treat unlabeled ('_') as reject",action="store_true")
parser.add_option("-v","--verbose",help="verbose",action="store_true")
parser.add_option("-E","--extractor",help="feature extractor",default="StandardExtractor")
parser.add_option("-N","--nvariants",help="number of variants to generate",default=0,type="int")
parser.add_option("-D","--distortion",help="maximum distortion",default=0.2,type="float")
parser.add_option("-d","--display",help="debug display",action="store_true")
parser.add_option("-Q","--maxthreads",help="max # of threads for training",type="int",default=4)

# different training modalities
parser.add_option("-g","--nogeometry",help="do not add geometric information",action="store_true")
parser.add_option("-r","--noreject",help="disable reject class",action="store_true")
parser.add_option("-1","--single",help="train only single chars, treat multiple as reject (combine with -r if you like)",action="store_true")
parser.add_option("-R","--rejectonly",help="only train reject/non-reject classifier",action="store_true")
parser.add_option("-L","--lengthpred",help="train a length predictor",action="store_true")
parser.add_option("-M","--multionly",help="only train multiple characters",action="store_true")

# the following are for limiting training samples
parser.add_option("-n","--limit",help="limit total training to n samples",default=999999999,type="int")
parser.add_option("-T","--threshold",help="threshold for estimating ntrain (either a percentile or a class)",default="70")
parser.add_option("-F","--rejectfactor",help="multiple of per class ntrain",type="float",default=20.0)

(options,args) = parser.parse_args()

# Report a missing -o as a regular usage error (usage text + message, exit
# status 2) rather than through an assert, which gives a traceback and is
# skipped entirely when Python runs with -O.
if options.output is None:
    parser.error("you must provide a -o flag")

# without any input databases there is nothing to do; just show the help
if len(args)<1:
    parser.print_help()
    sys.exit(0)

# use the same thread limit for classification and for training
mlp.maxthreads.value = options.maxthreads
mlp.maxthreads_train.value = options.maxthreads
# some utility functions we're going to need later
def pad(image,dx,dy,bgval=None):
    """Pad a 2D image with a constant border.

    Note the argument naming: `dx` rows are added at the top and at the
    bottom (axis 0), and `dy` columns are added at the left and at the
    right (axis 1), so the result has shape (h+2*dx, w+2*dy).

    If bgval is not given, the minimum of the input image is used as the
    border value. The result is a new float array; the input is unchanged.
    """
    if bgval is None: bgval = amin(image)
    h,w = image.shape
    result = zeros((h+2*dx,w+2*dy))
    result[:,:] = bgval
    # Use explicit end indices: the slice dx:-dx is empty when dx is 0,
    # which made zero padding along either axis fail.
    result[dx:dx+h,dy:dy+w] = image
    return result
def distort(image,sx,sy,sigma=10.0):
    """Warp a 2D image with a smooth random displacement field.

    Two uniform noise fields are smoothed with a Gaussian of width `sigma`
    and rescaled so that the largest vertical displacement is sy times the
    image height and the largest horizontal displacement is sx times the
    image width. The image is padded by those amounts (plus one pixel)
    before resampling with linear interpolation, so the returned array has
    the padded shape, not the input shape.
    """
    rows0,cols0 = image.shape
    ymax = sy*rows0
    xmax = sx*cols0
    padded = pad(image,int(ymax+1),int(xmax+1))
    rows,cols = padded.shape
    # vertical displacement first, then horizontal (keeps the random draw order)
    ycoords = filters.gaussian_filter(rand(*padded.shape)-0.5,sigma)
    ycoords *= ymax/amax(abs(ycoords))
    xcoords = filters.gaussian_filter(rand(*padded.shape)-0.5,sigma)
    xcoords *= xmax/amax(abs(xcoords))
    # turn the displacements into absolute sampling coordinates
    ycoords += arange(rows)[:,newaxis]
    xcoords += arange(cols)[newaxis,:]
    return interpolation.map_coordinates(padded,array([ycoords,xcoords]),order=1)
# unlabeled-as-reject implies that we train reject classes
if options.unlabeled: options.noreject = False
if len(args)<1:
print "must specify at least one character database as argument"
sys.exit(1)
output = options.output
if output is None:
output= os.path.splitext(args[0])[0]+".cmodel"
print "output to",output
if os.path.exists(output):
print output,"already exists"
sys.exit(1)
if not ".cmodel" in output and not ".pymodel" in output:
print "output",output,"should end in .cmodel or .pymodel"
sys.exit(1)
# create a Python window if requested
if options.display:
ion()
show()
# initialize the C++ debugging graphics (just in case it's needed for
# debugging the native code feature extractors)
classifier = ocrolib.make_IModel(options.model)
print "classifier",classifier
# helper for the progress window shown while training characters are loaded;
# `fig` counts how many characters have been displayed so far
fig = 0
def show_char(image,cls):
    """Draw `image` into the next cell of a 4x4 grid, labelled with `cls`.

    The figure is cleared and switched to the gray colormap each time the
    grid wraps around to its first cell. Updates the module-level counter
    `fig`.
    """
    global fig
    side = 4
    cell = fig%(side*side)
    if cell==0:
        clf()
        gray()
    subplot(side,side,cell+1)
    fig += 1
    imshow(image/255.0)
    text(3,3,str(cls),color="red",fontsize=14)
    draw()
    # brief wait for input, with a 0.1s timeout
    ginput(1,timeout=0.1)
print "training..."
count = 0
nchar = 0
skip0 = 0
skip3 = 0
nreject = 0
def iterate_db(arg):
print "===",arg,"==="
class Counter(dict):
    """A dict whose missing keys read as 0.

    Reading a missing key also stores the 0 under that key, so `c[k] += 1`
    works for keys that have not been seen before.
    """
    def __getitem__(self,index):
        try:
            return dict.__getitem__(self,index)
        except KeyError:
            # first access: remember the zero and hand it back
            dict.__setitem__(self,index,0)
            return dict.__getitem__(self,index)
totals = Counter()
actual = Counter()
for arg in args:
if not os.path.exists(arg):
print ""+arg+": not found"
sys.exit(1)
for arg in args:
print "===",arg
db = dbhelper.chardb(arg,options.table)
print "total",db.execute("select count(*) from chars").next()[0]
print "# determining per-class cutoff"
classes = [tuple(x) for x in db.execute("select cls,count(*) from chars group by cls")]
assert len(classes)>=2,"too few classes in database; got %s"%classes
counts= array([x[1] for x in classes],'f')
threshold = options.threshold
if re.match(r'[0-9][0-9.]+',threshold):
threshold = float(threshold)
assert len(counts)>1
assert threshold>=0 and threshold<=100
ntrain = int(stats.scoreatpercentile(counts,threshold))
else:
for k,v in classes:
if k==threshold:
ntrain = v
break
if ntrain is None:
print "class not found:",threshold
print "classes",len(counts),"stats",amin(counts),median(counts),mean(counts),amax(counts),"ntrain",ntrain
print "# sampling the classes"
all_ids = []
for k,v in classes:
ids = list(db.execute("select id from chars where cls=?",[k]))
ids = [x[0] for x in ids]
if k=="" or ord(k[0])<33: continue
if k=="_": continue
if k=="~": ntrain = int(ntrain*options.rejectfactor)
if len(ids)>ntrain: ids = pyrandom.sample(ids,ntrain)
print " %s:%d"%(k,len(ids)),; sys.stdout.flush()
all_ids += ids
print
counter = Counter()
print "# training",len(all_ids)
for id in all_ids:
c = db.execute("select image,cls,rel from chars where id=?",[id]).next()
cluster = Record(image=dbhelper.blob2image(c.image),cls=c.cls,rel=c.rel)
# check whether we already have enough characters
if count>options.limit:
break
cls = cluster.cls
if cls is None:
cls = "_"
counter[cls] += 1
totals[cls] += 1
# empty strings have no class at all, so we skip them
if len(cls)==0:
skip0 += 1
continue
# can't train transcriptions longer than three characters
if len(cls)>3:
skip3 += 1
continue
# if the user requested only single character training, skip multi-character transcriptions
if options.single and len(cls)>1: continue
# if the user requested it, treat unlabeled samples ("_") as reject classes
if options.unlabeled and cls=="_": cls = "~"
# if the user didn't want reject training, skip all reject classes
if options.noreject and cls=="~": continue
# skip any remaining unlabeled samples
if cls=="_": continue
if cls=="~": nreject += 1
# count the actually trained samples
actual[cls] += 1
geometry=None
if not options.nogeometry:
if cluster.rel is None or len(cluster.rel)<2:
print "training with geometry requested, but",arg,"lacks geometry information"
print "use the -g flag to turn off training with geometry"
sys.exit(1)
geometry = docproc.rel_geo_normalize(cluster.rel)
# add the image to the classifier
image = cluster.image/255.0
if options.display and nchar%1000==0: show_char(image,cls)
if min(image.shape)<3: continue
if amax(image)==amin(image): continue
if geometry is not None:
classifier.cadd(image,cls,geometry=geometry)
else:
classifier.cadd(image,cls)
count += 1
# if the user requested automatically generated variants,
# generate them and add them to the classifier as well
for i in range(options.nvariants):
if count>options.limit: break
sx = options.distortion
sy = options.distortion
distorted = distort(image,sx,sy)
if options.display and nchar%1000==0: show_char(distorted,cls)
try:
if geometry is not None:
classifier.cadd(image,cls,geometry=geometry)
else:
classifier.cadd(image,cls)
count += 1
except:
traceback.print_exc()
continue
nchar += 1
if nchar%10000==0: print "training",nchar,count
print
print "summary statistics:"
print sorted(list(actual.items()))
print
print "training",count,"variants representing",nchar,"training characters"
if skip0>0: print "skipped",skip0,"zero-length transcriptions"
if skip3>0: print "skipped",skip3,"transcriptions that were more than three characters long"
print
# provide some feedback about the training process
if options.verbose:
if "info" in dir(classifier):
classifier.info()
if "getExtractor" in dir(classifiers) and "info" in dir(classifier.getExtractor()):
classifier.getExtractor().info()
# now perform the actual training
# (usually, this is AutoMLP, a multi-threaded, long-running stochastic
# gradient descent training for a multi-layer perceptron)
print "starting classifier training"
round = 0
for progress in classifier.updateModel1(verbose=1):
print "ocropus-ctrain:",round,progress
modelbase,_ = os.path.splitext(options.output)
ocrolib.save_component("%s.%03d.cmodel"%(modelbase,round),classifier)
round += 1
if round>=options.nrounds: break
# provide some feedback about the training process
if options.verbose:
classifier.info()
# save the resulting model
print "saving",output
ocrolib.save_component(output,classifier)