Permalink
Browse files

ENH Tree strategy for multi-class learning

  • Loading branch information...
1 parent 3a464cd commit 4396f5b1fef55a3eecc6f753bf092c653ecbc55b @luispedro committed Mar 27, 2012
Showing with 89 additions and 0 deletions.
  1. +1 −0 ChangeLog
  2. +71 −0 milk/supervised/multi.py
  3. +17 −0 milk/tests/test_multi.py
View
@@ -1,6 +1,7 @@
Version 0.4.2+
* Add select_n_best & rank_corr to featureselection
* Add Euclidean MDS
+ * Add tree multi-class strategy
Version 0.4.2 2012-01-16 by luispedro
* Make defaultlearner able to take extra arguments
View
@@ -14,6 +14,7 @@
'one_against_one',
'one_against_rest_multi',
'ecoc_learner',
+ 'multi_tree_learner',
]
def _asanyarray(f):
@@ -244,3 +245,73 @@ def train(self, features, labels, normalisedlabels=False, **kwargs):
models.append(self.base.train(features, nlabels, normalisedlabels=True, **kwargs))
return ecoc_model(models, codes, self.probability)
+
def split(counts):
    '''
    g0, g1 = split(counts)

    Greedily partition class positions into two groups whose total example
    counts are as balanced as possible: classes are visited from most to
    least frequent and each one joins whichever group is currently lighter.

    Parameters
    ----------
    counts : ndarray
        counts[i] is the number of examples of class i

    Returns
    -------
    g0, g1 : lists of class positions (indices into ``counts``)
    '''
    left = []
    right = []
    left_weight = 0.0
    right_weight = 0.0
    for c in counts.argsort()[::-1]:
        # ties go to the left group (matches np.argmin's first-minimum rule)
        if left_weight <= right_weight:
            left.append(c)
            left_weight += counts[c]
        else:
            right.append(c)
            right_weight += counts[c]
    return left, right
+
+
class multi_tree_model(supervised_model):
    def __init__(self, model):
        '''
        Model built by multi_tree_learner.

        ``model`` is either a length-1 leaf (containing the predicted
        label) or a ``(binary_model, left, right)`` triple.
        '''
        self.model = model

    def apply(self, feats):
        '''
        Walk the decision tree iteratively: at every internal node the
        binary model picks the left subtree on a truthy prediction and
        the right subtree otherwise, until a leaf label is reached.
        '''
        node = self.model
        while len(node) != 1:
            binary, left, right = node
            node = left if binary.apply(feats) else right
        return node[0]
+
class multi_tree_learner(base_adaptor):
    r'''
    Implements a multi-class learner as a tree of binary decisions.

    At each level, labels are split into 2 groups in a way that attempts to
    balance the number of examples on each side (and not the number of labels
    on each side). This means that on a 4 class problem with a distribution
    like [ 50% 25% 12.5% 12.5% ], the "optimal" splits are

              o
             / \
            /   \
          [0]    o
                / \
              [1]  o
                  / \
                [2][3]

    where all comparisons are perfectly balanced.
    '''

    def train(self, features, labels, normalisedlabels=False, **kwargs):
        '''
        model = learner.train(features, labels)

        Builds the decision tree and returns a multi_tree_model.
        '''
        if not normalisedlabels:
            labels, names = normaliselabels(labels)
            labelset = np.arange(len(names))
        else:
            labels = np.asanyarray(labels)
            labelset = np.arange(labels.max() + 1)

        def recursive(labelset, counts):
            # Leaf: a single label remains
            if len(labelset) == 1:
                return labelset
            g0, g1 = split(counts)
            # split() returns *positions* into the current labelset;
            # map them back to actual label values before comparing
            # against `labels` (the original code compared raw labels
            # with positions, which is only correct at the root).
            left_labels = frozenset(labelset[g0])
            node_labels = left_labels | frozenset(labelset[g1])
            # Train this node's binary split only on the examples that
            # belong to its subtree (otherwise foreign examples pollute
            # the "right" class at every level below the root).
            idx = [i for i, ell in enumerate(labels) if ell in node_labels]
            nfeatures = [features[i] for i in idx]
            nlabels = np.array([int(labels[i] in left_labels) for i in idx])
            # NB: the keyword is `normalisedlabels` (with a 'd'), matching
            # the convention used elsewhere in this module.
            model = self.base.train(nfeatures, nlabels, normalisedlabels=True, **kwargs)
            m0 = recursive(labelset[g0], counts[g0])
            m1 = recursive(labelset[g1], counts[g1])
            return (model, m0, m1)

        counts = np.zeros(labels.max() + 1)
        for ell in labels:
            counts[ell] += 1
        return multi_tree_model(recursive(np.arange(labels.max() + 1), counts))
+
View
@@ -59,3 +59,20 @@ def test_classifier_no_set_options():
milk.supervised.multi.one_against_rest(fast_classifier())
milk.supervised.multi.one_against_one(fast_classifier())
+
def test_tree():
    learner = milk.supervised.multi.multi_tree_learner(fast_classifier())
    labels = [0, 1, 2, 2, 3, 3, 3, 3]
    features = np.random.random_sample((len(labels), 8))
    model = learner.train(features, labels)

    counts = np.zeros(4)
    for lab in labels:
        counts[lab] += 1

    # class 3 holds half the examples, so it must sit alone on one side
    g0, g1 = milk.supervised.multi.split(counts)
    assert np.all(g0 == [3]) or np.all(g1 == [3])

    def shape(node):
        # leaves collapse to ints; internal nodes to sorted 2-element lists
        if len(node) == 1:
            return int(node[0])
        return sorted([shape(node[1]), shape(node[2])])

    assert shape(model.model) == [3, [2, [0, 1]]]
+

0 comments on commit 4396f5b

Please sign in to comment.