diff --git a/few/few.py b/few/few.py index a41bd23..7601c63 100644 --- a/few/few.py +++ b/few/few.py @@ -53,7 +53,7 @@ def __init__(self, population_size=50, generations=100, mutation_rate=0.5, crossover_rate=0.5, ml = None, min_depth = 1, max_depth = 2, max_depth_init = 2, sel = 'epsilon_lexicase', tourn_size = 2, fit_choice = None, - op_weight = False, seed_with_ml = True, erc = False, + op_weight = False, max_stall=10, seed_with_ml = True, erc = False, random_state=np.random.randint(9999999), verbosity=0, scoring_function=None, disable_update_check=False, elitism=True, boolean = False,classification=False,clean=False, @@ -90,6 +90,7 @@ def __init__(self, population_size=50, generations=100, self.tourn_size = tourn_size self.fit_choice = fit_choice self.op_weight = op_weight + self.max_stall = max_stall self.seed_with_ml = seed_with_ml self.erc = erc self.random_state = random_state @@ -272,8 +273,11 @@ def fit(self, features, labels): # progress bar pbar = tqdm(total=self.generations,disable = self.verbosity==0, desc='Internal CV: {:1.3f}'.format(self._best_score)) + stall_count = 0 # for each generation g for g in np.arange(self.generations): + if stall_count == self.max_stall: + break; if self.track_diversity: self.get_diversity(self.X) @@ -344,9 +348,12 @@ def fit(self, features, labels): if self.valid_loc() and tmp_score > self._best_score: self._best_estimator = copy.deepcopy(self.ml) self._best_score = tmp_score + stall_count = 0; self._best_inds = copy.deepcopy(self.valid()) if self.verbosity > 1: print("updated best internal CV:",self._best_score) + else: + stall_count = stall_count + 1 # Variation if self.verbosity > 2: @@ -808,6 +815,10 @@ def main(): type=bool, help='Weight attributes for incuded in' ' features based on ML scores. Default: off') + parser.add_argument('-ms', action='store', dest='MAX_STALL',default=10, + help='The number of iterations to do when the best' + ' score is not improving before taking the best score value') + parser.add_argument('-sel', action='store', dest='SEL', default='epsilon_lexicase', choices = ['tournament','lexicase','epsilon_lexicase', @@ -928,6 +939,7 @@ def main(): min_depth = args.MIN_DEPTH,max_depth = args.MAX_DEPTH, sel = args.SEL, tourn_size = args.TOURN_SIZE, seed_with_ml = args.SEED_WITH_ML, op_weight = args.OP_WEIGHT, + max_stall = args.MAX_STALL, erc = args.ERC, random_state=args.RANDOM_STATE, verbosity=args.VERBOSITY, disable_update_check=args.DISABLE_UPDATE_CHECK,