-
Notifications
You must be signed in to change notification settings - Fork 588
/
ocroex-gmmtree-test
executable file
·73 lines (62 loc) · 1.92 KB
/
ocroex-gmmtree-test
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
#!/usr/bin/python
import code,pickle,sys,os,re,cPickle
from pylab import *
from optparse import OptionParser
import traceback
from ocrolib import dbtables,improc,gmmtree
# Re-import gmmtree so that the current version of the module is used when
# this script is run again in the same interpreter (presumably to pick up
# edits during interactive development -- TODO confirm).
reload(gmmtree)
# Turn on pylab's interactive mode and clear the current figure.
ion()
clf()
print "starting"
# Number of character samples to load; also the number of rows allocated
# for the data array below.
n = 100000
# Load up to n single-character samples from the "chars" table of chars-1.db:
#   data   -- 'f' (single-precision) array, one flattened r*r image per row
#   values -- list of class labels, parallel to the rows of data
# The load is skipped if a `data` of length n is already defined, which is
# convenient when re-running the script interactively (e.g. from Emacs).
if "data" not in dir() or len(data)!=n:
    dbfile = "chars-1.db"
    table = dbtables.Table(dbfile,"chars")
    table.verbose = 1
    # presumably decodes the stored blob into an image array -- TODO confirm
    table.converter("image",dbtables.SmallImage())
    table.create(image="blob",count="integer",cls="text",classes="text",key="text")
    rows = table.get(random_=0)
    r = 32    # side length of the size-normalized character images
    data = zeros((n,r*r),'f')
    values = [None]*n
    i = 0     # number of samples stored so far
    for row in rows:
        if i>=n: break
        cls = row.cls
        # keep only single-character labels, excluding "~" and "_"
        if cls is None: continue
        if len(cls)!=1: continue
        if cls=="~" or cls=="_": continue
        image = array(row.image,'f')
        # scale to a maximum of 1; a blank image would divide by zero and
        # put NaNs into data, so it is skipped instead
        peak = amax(image)
        if peak==0: continue
        image /= peak
        # assumes center_maxsize returns an r x r image (the assignment into
        # data[i,:] below needs exactly r*r elements) -- TODO confirm
        image = improc.center_maxsize(image,r)
        # scale to unit Euclidean length, again guarding against zero
        norm = sqrt(sum(image**2))
        if norm==0: continue
        image /= norm
        if i%1000==0: print(i)
        data[i,:] = image.ravel()
        values[i] = cls
        i += 1
    if i<n:
        # Fewer usable rows than requested: drop the unfilled all-zero rows
        # and None labels so they are not mistaken for samples later.
        # NOTE: len(data) then differs from n, so the next run reloads.
        print("only %d of %d samples loaded" % (i,n))
        data = data[:i]
        values = values[:i]
# tree = gmmtree.GmmTree()
# tree.build(data,values)
# means,sigmas = gmmtree.gmm_em(data,int(sqrt(len(data))))
# Fit a k-component mixture (gmmtree.gmm_em) to the samples of each character
# class, display the result, and collect the per-class means and sigmas.
k = 25                # mixture components per class
min_samples = 1000    # classes with fewer samples are skipped
all_means = []
all_sigmas = []

# Group the sample indexes by class label in one pass over values.
# Unfilled slots (None) are not a class and are ignored.
indexes_by_class = {}
for index,value in enumerate(values):
    if value is None: continue
    indexes_by_class.setdefault(value,[]).append(index)

for c in sorted(indexes_by_class):
    indexes = array(indexes_by_class[c],'i')
    rows = data[indexes,:]
    print("=== %s %s" % (c,len(rows)))
    if len(rows)<min_samples: continue
    means,sigmas,counts = gmmtree.gmm_em(rows,k)
    counts = array(counts,'i')
    # reorder the components by decreasing count
    order = argsort(-counts)
    means = means[order,:]
    sigmas = sigmas[order,:]
    counts = counts[order]
    print(counts)
    # samples not accounted for by the component counts are reported as outliers
    print("outliers %s :: %s %s" % (len(rows)-sum(counts),len(rows),sum(counts)))
    all_means.append(means)
    all_sigmas.append(sigmas)
    # NOTE(review): sigmas is passed twice to showgrid3 -- confirm this is intended
    gmmtree.showgrid3(sigmas,sigmas,means,d=int(sqrt(k)))
    # wait for a mouse click, or at most 2 seconds, before the next class
    ginput(1,timeout=2.0)

# Stack the per-class results into single arrays. row_stack raises a
# ValueError on an empty list, so that case is reported explicitly instead.
if all_means:
    all_means = row_stack(all_means)
    all_sigmas = row_stack(all_sigmas)
else:
    print("no class has at least %d samples; nothing to stack" % min_samples)