Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Multi-class selector.

  • Loading branch information...
commit 579ae691d07a458826ef5aad996dc81c10f082cc 1 parent 30b90ba
@kvh authored
Showing with 31 additions and 8 deletions.
  1. +5 −2 ramp/dataset.py
  2. +26 −6 ramp/selectors.py
View
7 ramp/dataset.py
@@ -142,7 +142,7 @@ def get_validation_data(self):
def get_saved_configurations(self, **kwargs):
try:
- saved = self.load('saved_scores')
+ saved = self.store.load('saved_scores')
except KeyError:
saved = []
if kwargs:
@@ -153,7 +153,10 @@ def save_configurations(self, configs):
"""configs: list of (cv_scores, config) """
saved = self.get_saved_configurations()
saved.extend(configs)
- self.dump('saved_scores', saved) #saved[:self.keep_nmodels])
+ self.store.save('saved_scores', saved) #saved[:self.keep_nmodels])
+
+ def save_config(self, config, scores):
+ self.save_configurations((scores, copy.copy(config)))
# def get_best_configuration(self, rank=0, weight_func=None, **kwargs):
# saved = self.get_saved_configurations()
View
32 ramp/selectors.py
@@ -146,7 +146,7 @@ def sets(self, x, y, n_keep):
class BinaryFeatureSelector(Selector):
- """ Only for binary classification and binary(-able) features """
+ """ Only for classification and binary(-able) features """
def __init__(self, type='bns', *args, **kwargs):
""" type in ('bns', 'acc')
@@ -156,8 +156,31 @@ def __init__(self, type='bns', *args, **kwargs):
def sets(self, x, y, n_keep):
cnts = y.value_counts()
- assert(len(cnts) == 2)
print "Computing binary feature scores for %d features..." % len(x.columns)
+ if len(cnts) > 2:
+ scores = self.round_robin(x, y, n_keep)
+ else:
+ scores = self.rank(x, y)
+ if self.verbose:
+ # just show top few hundred
+ print scores[:200]
+ return [s[1] for s in scores[:n_keep]]
+
+ def round_robin(self, x, y, n_keep):
+ """ Ensures all classes get representative features, not just those with strong features """
+ vals = y.unique()
+ scores = {}
+ for cls in vals:
+ scores[cls] = self.rank(x, np.equal(cls, y).astype('Int64'))
+ scores[cls].reverse()
+ keepers = []
+ while len(keepers) < n_keep:
+ for cls in vals:
+ keepers.append(scores[cls].pop())
+ return keepers
+
+ def rank(self, x, y):
+ cnts = y.value_counts()
scores = []
for c in x.columns:
true_positives = np.count_nonzero(np.logical_and(x[c], y))
@@ -172,10 +195,7 @@ def sets(self, x, y, n_keep):
score = abs(tpr - fpr)
scores.append((score, c))
scores.sort(reverse=True)
- if self.verbose:
- # just show top few hundred
- print scores[:200]
- return [s[1] for s in scores[:n_keep]]
+ return scores
class InformationGainSelector(Selector):
""" Only for binary classification """

0 comments on commit 579ae69

Please sign in to comment.
Something went wrong with that request. Please try again.