
TypeError: '<' not supported between instances of 'str' and 'int' when using multiclassification #2610

Closed
cia05rf opened this issue Dec 3, 2019 · 3 comments · Fixed by #2619

cia05rf commented Dec 3, 2019

I'm running a multiclass classification model and it was running fine until a recent upgrade.

It is now throwing a TypeError (as stated in the title). A few things to note:

  • I am using a custom loss function (passed as eval_metric), but this error still occurs even when using logloss as the metric.
  • I am using a validation set separate from the training set, passed via 'eval_set'.

The code is:

#Build a custom loss function
def lgbm_custom_loss(_y_act,_y_pred):
    #Convert _y_pred into classes
    _y_pred_conv = []
    _n_classes = len(np.unique(_y_act))
    for i in range(0,_y_act.shape[0]):
        _tmp_li = []
        for j in range(0,_n_classes):
            _tmp_li.append(_y_pred[(_y_act.shape[0]*j) + i])
        _y_pred_conv.append(np.argmax(_tmp_li))
    _y_pred_conv = np.array(_y_pred_conv)
    _ac_results = calc_tpr(_y_pred_conv,_y_act,unique_classes=range(0,_n_classes))
    _av_ppv = _ac_results[_ac_results.opt_text.isin([0,2])]['ppv'].mean() #Only average for buy and sell
    #If _av_ppv is 0 then append the time to prevent early stopping
    if _av_ppv == 0:
        _time_now = dt.datetime.now()
        _av_ppv += _time_now.hour*10**-4 + _time_now.minute*10**-6 + _time_now.second*10**-8
    # (eval_name, eval_result, is_higher_better)
    return 'lgbm_custom_loss',_av_ppv,True

#Create parameters grid
#Create fixed parameters
mod_fixed_params = {
    'boosting_type':'gbdt'
    ,'random_state':0
    ,'silent':False
    ,'objective':'multiclass'
    ,'num_class':y_train.unique().shape[0]
    ,'min_samples_split':200 #Should be between 0.5-1% of samples
    ,'min_samples_leaf':50
    ,'subsample':0.8
}
print('mod_fixed_params -> {}'.format(mod_fixed_params))
search_params = {
    'fixed':{
        'cv':2
        ,'n_iter':1
        ,'verbose':True
        ,'random_state':0
    }
    ,'variable':{
        'learning_rate':[0.1,0.01,0.005]
        ,'num_leaves':np.linspace(10,1010,100,dtype=int)
        ,'max_depth':np.linspace(2,22,10,dtype=int)
    }
}
print('search_params -> {}'.format(search_params))
fit_params = {
    'verbose':True
    ,'eval_set':[(X_valid,y_valid)]
    ,'eval_metric':lgbm_custom_loss
    ,'early_stopping_rounds':5
}
print('fit_params -> {}'.format(fit_params))

#Setup the model
lgb_mod = lgb.LGBMClassifier(**mod_fixed_params)
#Add the search grid
gbm = RandomizedSearchCV(lgb_mod,search_params['variable'],**search_params['fixed'])
#Fit the model
gbm.fit(X_train,y_train,**fit_params)
print('Best parameters found by grid search are: {}'.format(gbm.best_params_))
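
(Side note, not part of the original report: the nested loops in lgbm_custom_loss that turn the flattened predictions back into class labels can be collapsed into one reshape/argmax, assuming the same class-major layout the loops already index into; newer LightGBM releases may instead hand the sklearn eval metric an (n_samples, n_classes) array, so treat this as a hedged sketch rather than the canonical form.)

import numpy as np

def predicted_classes(y_act, y_pred):
    # vectorised equivalent of the conversion loop above:
    # reshape the flat, class-major prediction vector to (n_classes, n_samples)
    # and take the argmax for each sample (column)
    n_classes = len(np.unique(y_act))
    return np.argmax(np.asarray(y_pred).reshape(n_classes, -1), axis=0)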

Packages are:
numpy==1.17.0
sklearn==0.19.2
lightgbm==2.3.1

cia05rf commented Dec 3, 2019

The issue seems to be in sklearn.preprocessing -> transform. Here it compares self.classes_ and classes, but for some reason classes is still encoded. I think this has something to do with the fact that it's using a different set of data for validation than it is for training.

I reckon somewhere it's encoding y_valid and then passing the already-encoded values to sklearn.preprocessing -> transform, where they no longer match self.classes_.
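
That hypothesis is easy to check in isolation. A minimal sketch of my own (not code from the report): fit a LabelEncoder on string classes, then feed it labels that have already been encoded to integers, which is exactly the str/int mix the traceback complains about.

import numpy as np
from sklearn.preprocessing import LabelEncoder

le = LabelEncoder()
le.fit(np.array(['buy', 'hold', 'sell'], dtype=object))           # classes_ hold the strings
encoded = le.transform(np.array(['buy', 'sell'], dtype=object))   # -> array([0, 2]), fine

# Feeding the already-encoded integers back in mixes str and int when
# transform() compares them against classes_:
le.transform(encoded)
# sklearn 0.19.2 / numpy 1.17 (this report): TypeError: '<' not supported between instances of 'str' and 'int'
# newer scikit-learn: ValueError: y contains previously unseen labels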

The error trace is:

Fitting 3 folds for each of 80 candidates, totalling 240 fits
[1]	valid_0's multi_logloss: 0.967121	valid_0's lgbm_custom_loss: 0.00210429
Training until validation scores don't improve for 5 rounds
[2]	valid_0's multi_logloss: 0.960505	valid_0's lgbm_custom_loss: 0.00210431
[3]	valid_0's multi_logloss: 0.95471	valid_0's lgbm_custom_loss: 0.00210433
[4]	valid_0's multi_logloss: 0.949429	valid_0's lgbm_custom_loss: 0.00210436
[5]	valid_0's multi_logloss: 0.944814	valid_0's lgbm_custom_loss: 0.00210438
[6]	valid_0's multi_logloss: 0.940716	valid_0's lgbm_custom_loss: 0.166667
[7]	valid_0's multi_logloss: 0.936959	valid_0's lgbm_custom_loss: 0.1
[8]	valid_0's multi_logloss: 0.933548	valid_0's lgbm_custom_loss: 0.636364
[9]	valid_0's multi_logloss: 0.930506	valid_0's lgbm_custom_loss: 0.512987
[10]	valid_0's multi_logloss: 0.927824	valid_0's lgbm_custom_loss: 0.46119
[11]	valid_0's multi_logloss: 0.92536	valid_0's lgbm_custom_loss: 0.5
[12]	valid_0's multi_logloss: 0.923067	valid_0's lgbm_custom_loss: 0.494774
[13]	valid_0's multi_logloss: 0.92097	valid_0's lgbm_custom_loss: 0.514035
Early stopping, best iteration is:
[8]	valid_0's multi_logloss: 0.933548	valid_0's lgbm_custom_loss: 0.636364
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-17-99e568e4a059> in <module>()
      6 gbm = RandomizedSearchCV(lgb_mod,search_params['variable'],**search_params['fixed'])
      7 #Fit the model
----> 8 gbm.fit(X_train,y_train,**fit_params)
      9 print('Best parameters found by grid search are: {}'.format(gbm.best_params_))
     10 run_time.end()

~\Anaconda3\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
    638                                   error_score=self.error_score)
    639           for parameters, (train, test) in product(candidate_params,
--> 640                                                    cv.split(X, y, groups)))
    641 
    642         # if one choose to see train score, "out" will contain train score info

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self, iterable)
    777             # was dispatched. In particular this covers the edge
    778             # case of Parallel used with an exhausted iterator.
--> 779             while self.dispatch_one_batch(iterator):
    780                 self._iterating = True
    781             else:

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in dispatch_one_batch(self, iterator)
    623                 return False
    624             else:
--> 625                 self._dispatch(tasks)
    626                 return True
    627 

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in _dispatch(self, batch)
    586         dispatch_timestamp = time.time()
    587         cb = BatchCompletionCallBack(dispatch_timestamp, len(batch), self)
--> 588         job = self._backend.apply_async(batch, callback=cb)
    589         self._jobs.append(job)
    590 

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py in apply_async(self, func, callback)
    109     def apply_async(self, func, callback=None):
    110         """Schedule a func to be run"""
--> 111         result = ImmediateResult(func)
    112         if callback:
    113             callback(result)

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py in __init__(self, batch)
    330         # Don't delay the application, to avoid keeping the input
    331         # arguments in memory
--> 332         self.results = batch()
    333 
    334     def get(self):

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
    132 
    133     def __len__(self):

~\Anaconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in <listcomp>(.0)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
    132 
    133     def __len__(self):

~\Anaconda3\lib\site-packages\sklearn\model_selection\_validation.py in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, error_score)
    456             estimator.fit(X_train, **fit_params)
    457         else:
--> 458             estimator.fit(X_train, y_train, **fit_params)
    459 
    460     except Exception as e:

~\Anaconda3\lib\site-packages\lightgbm\sklearn.py in fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks)
    791                     eval_set[i] = (valid_x, _y)
    792                 else:
--> 793                     eval_set[i] = (valid_x, self._le.transform(valid_y))
    794 
    795         super(LGBMClassifier, self).fit(X, _y, sample_weight=sample_weight,

~\Anaconda3\lib\site-packages\sklearn\preprocessing\label.py in transform(self, y)
    129         y = column_or_1d(y, warn=True)
    130         classes = np.unique(y)
--> 131         if len(np.intersect1d(classes, self.classes_)) < len(classes):
    132             diff = np.setdiff1d(classes, self.classes_)
    133             raise ValueError("y contains new labels: %s" % str(diff))

<__array_function__ internals> in intersect1d(*args, **kwargs)

~\Anaconda3\lib\site-packages\numpy\lib\arraysetops.py in intersect1d(ar1, ar2, assume_unique, return_indices)
    413         aux = aux[aux_sort_indices]
    414     else:
--> 415         aux.sort()
    416 
    417     mask = aux[1:] == aux[:-1]

TypeError: '<' not supported between instances of 'str' and 'int'

jameslamb added the bug label Dec 3, 2019
StrikerRUS (Collaborator) commented:

I cannot reproduce the error. Maybe you can try a newer version of scikit-learn? Probably they have fixed something. Here is the modified script I ran locally:

import numpy as np
import lightgbm as lgb

from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split, RandomizedSearchCV

#Build a custom loss function
def lgbm_custom_loss(y_true , y_pred):
    return 'lgbm_custom_loss', 1, True

X, y = load_digits(4, True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.1, random_state=42)

#Create parameters grid
#Create fixed parameters
mod_fixed_params = {
    'boosting_type':'gbdt'
    ,'random_state':0
    ,'silent':False
    ,'objective':'multiclass'
    ,'num_class':np.unique(y_train)
    ,'min_samples_split':200 #Should be between 0.5-1% of samples
    ,'min_samples_leaf':50
    ,'subsample':0.8
}

search_params = {
    'fixed':{
        'cv':2
        ,'n_iter':1
        ,'verbose':True
        ,'random_state':0
    }
    ,'variable':{
        'learning_rate':[0.1,0.01,0.005]
        ,'num_leaves':np.linspace(10,1010,100,dtype=int)
        ,'max_depth':np.linspace(2,22,10,dtype=int)
    }
}

fit_params = {
    'verbose':True
    ,'eval_set':[(X_valid,y_valid)]
    ,'eval_metric':lgbm_custom_loss
    ,'early_stopping_rounds':5
}


#Setup the model
lgb_mod = lgb.LGBMClassifier(**mod_fixed_params)
#Add the search grid
gbm = RandomizedSearchCV(lgb_mod,search_params['variable'],**search_params['fixed'])
#Fit the model
gbm.fit(X_train,y_train,**fit_params)
print('Best parameters found by grid search are: {}'.format(gbm.best_params_))

and the diff

@@ -1,22 +1,15 @@
+import numpy as np
+import lightgbm as lgb
+
+from sklearn.datasets import load_digits
+from sklearn.model_selection import train_test_split, RandomizedSearchCV
+
 #Build a custom loss function
-def lgbm_custom_loss(_y_act,_y_pred):
-    #Convert _y_pred into classes
-    _y_pred_conv = []
-    _n_classes = len(np.unique(_y_act))
-    for i in range(0,_y_act.shape[0]):
-        _tmp_li = []
-        for j in range(0,_n_classes):
-            _tmp_li.append(_y_pred[(_y_act.shape[0]*j) + i])
-        _y_pred_conv.append(np.argmax(_tmp_li))
-    _y_pred_conv = np.array(_y_pred_conv)
-    _ac_results = calc_tpr(_y_pred_conv,_y_act,unique_classes=range(0,_n_classes))
-    _av_ppv = _ac_results[_ac_results.opt_text.isin([0,2])]['ppv'].mean() #Only average for buy and sell
-    #If _av_ppv is 0 then append the time to prevent early stopping
-    if _av_ppv == 0:
-        _time_now = dt.datetime.now()
-        _av_ppv += _time_now.hour*10**-4 + _time_now.minute*10**-6 + _time_now.second*10**-8
-    # (eval_name, eval_result, is_higher_better)
-    return 'lgbm_custom_loss',_av_ppv,True
+def lgbm_custom_loss(y_true , y_pred):
+    return 'lgbm_custom_loss', 1, True
+
+X, y = load_digits(4, True)
+X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.1, random_state=42)

 #Create parameters grid
 #Create fixed parameters
@@ -25,12 +18,12 @@ mod_fixed_params = {
     ,'random_state':0
     ,'silent':False
     ,'objective':'multiclass'
-    ,'num_class':y_train.unique().shape[0]
+    ,'num_class':np.unique(y_train)
     ,'min_samples_split':200 #Should be between 0.5-1% of samples
     ,'min_samples_leaf':50
     ,'subsample':0.8
 }
-print('mod_fixed_params -> {}'.format(mod_fixed_params))
+
 search_params = {
     'fixed':{
         'cv':2
@@ -44,14 +37,14 @@ search_params = {
         ,'max_depth':np.linspace(2,22,10,dtype=int)
     }
 }
-print('search_params -> {}'.format(search_params))
+
 fit_params = {
     'verbose':True
     ,'eval_set':[(X_valid,y_valid)]
     ,'eval_metric':lgbm_custom_loss
     ,'early_stopping_rounds':5
 }
-print('fit_params -> {}'.format(fit_params))
+

 #Setup the model
 lgb_mod = lgb.LGBMClassifier(**mod_fixed_params)

It runs just fine:

Fitting 2 folds for each of 1 candidates, totalling 2 fits
[1]	valid_0's multi_logloss: 1.38043	valid_0's lgbm_custom_loss: 1
Training until validation scores don't improve for 5 rounds
[2]	valid_0's multi_logloss: 1.37246	valid_0's lgbm_custom_loss: 1
[3]	valid_0's multi_logloss: 1.36447	valid_0's lgbm_custom_loss: 1
[4]	valid_0's multi_logloss: 1.35664	valid_0's lgbm_custom_loss: 1
[5]	valid_0's multi_logloss: 1.3488	valid_0's lgbm_custom_loss: 1
[6]	valid_0's multi_logloss: 1.34114	valid_0's lgbm_custom_loss: 1
Early stopping, best iteration is:
[1]	valid_0's multi_logloss: 1.38043	valid_0's lgbm_custom_loss: 1
[1]	valid_0's multi_logloss: 1.3813	valid_0's lgbm_custom_loss: 1
Training until validation scores don't improve for 5 rounds
[2]	valid_0's multi_logloss: 1.37327	valid_0's lgbm_custom_loss: 1
[3]	valid_0's multi_logloss: 1.36538	valid_0's lgbm_custom_loss: 1
[4]	valid_0's multi_logloss: 1.35768	valid_0's lgbm_custom_loss: 1
[5]	valid_0's multi_logloss: 1.34994	valid_0's lgbm_custom_loss: 1
[6]	valid_0's multi_logloss: 1.34237	valid_0's lgbm_custom_loss: 1
Early stopping, best iteration is:
[1]	valid_0's multi_logloss: 1.3813	valid_0's lgbm_custom_loss: 1
[1]	valid_0's multi_logloss: 1.38039	valid_0's lgbm_custom_loss: 1
Training until validation scores don't improve for 5 rounds
[2]	valid_0's multi_logloss: 1.37198	valid_0's lgbm_custom_loss: 1

[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.
[Parallel(n_jobs=1)]: Done   2 out of   2 | elapsed:    0.0s finished


[3]	valid_0's multi_logloss: 1.36366	valid_0's lgbm_custom_loss: 1
[4]	valid_0's multi_logloss: 1.35537	valid_0's lgbm_custom_loss: 1
[5]	valid_0's multi_logloss: 1.34717	valid_0's lgbm_custom_loss: 1
[6]	valid_0's multi_logloss: 1.33908	valid_0's lgbm_custom_loss: 1
Early stopping, best iteration is:
[1]	valid_0's multi_logloss: 1.38039	valid_0's lgbm_custom_loss: 1


Best parameters found by grid search are: {'num_leaves': 333, 'max_depth': 17, 'learning_rate': 0.005}

Packages are:
numpy==1.16.2
sklearn==0.20.3
lightgbm==2.3.1

StrikerRUS (Collaborator) commented:

It seems that I've managed to reproduce the error. However, I get a different error message; I guess that's due to the scikit-learn version mismatch.

The error can be observed only with non-trivial ys. For example,

...
X, y = load_digits(4, True)
y = np.array(list(map(str, y)))
...
ValueError                                Traceback (most recent call last)
<ipython-input-2-8374d53e3ed0> in <module>
     54 gbm = RandomizedSearchCV(lgb_mod,search_params['variable'],**search_params['fixed'])
     55 #Fit the model
---> 56 gbm.fit(X_train,y_train,**fit_params)
     57 print('Best parameters found by grid search are: {}'.format(gbm.best_params_))

D:\Miniconda3\lib\site-packages\sklearn\model_selection\_search.py in fit(self, X, y, groups, **fit_params)
    720                 return results_container[0]
    721 
--> 722             self._run_search(evaluate_candidates)
    723 
    724         results = results_container[0]

D:\Miniconda3\lib\site-packages\sklearn\model_selection\_search.py in _run_search(self, evaluate_candidates)
   1513         evaluate_candidates(ParameterSampler(
   1514             self.param_distributions, self.n_iter,
-> 1515             random_state=self.random_state))

D:\Miniconda3\lib\site-packages\sklearn\model_selection\_search.py in evaluate_candidates(candidate_params)
    709                                for parameters, (train, test)
    710                                in product(candidate_params,
--> 711                                           cv.split(X, y, groups)))
    712 
    713                 all_candidate_params.extend(candidate_params)

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self, iterable)
    918                 self._iterating = self._original_iterator is not None
    919 
--> 920             while self.dispatch_one_batch(iterator):
    921                 pass
    922 

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in dispatch_one_batch(self, iterator)
    757                 return False
    758             else:
--> 759                 self._dispatch(tasks)
    760                 return True
    761 

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in _dispatch(self, batch)
    714         with self._lock:
    715             job_idx = len(self._jobs)
--> 716             job = self._backend.apply_async(batch, callback=cb)
    717             # A job can complete so quickly than its callback is
    718             # called before we get here, causing self._jobs to

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py in apply_async(self, func, callback)
    180     def apply_async(self, func, callback=None):
    181         """Schedule a func to be run"""
--> 182         result = ImmediateResult(func)
    183         if callback:
    184             callback(result)

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\_parallel_backends.py in __init__(self, batch)
    547         # Don't delay the application, to avoid keeping the input
    548         # arguments in memory
--> 549         self.results = batch()
    550 
    551     def get(self):

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in __call__(self)
    223         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    224             return [func(*args, **kwargs)
--> 225                     for func, args, kwargs in self.items]
    226 
    227     def __len__(self):

D:\Miniconda3\lib\site-packages\sklearn\externals\joblib\parallel.py in <listcomp>(.0)
    223         with parallel_backend(self._backend, n_jobs=self._n_jobs):
    224             return [func(*args, **kwargs)
--> 225                     for func, args, kwargs in self.items]
    226 
    227     def __len__(self):

D:\Miniconda3\lib\site-packages\sklearn\model_selection\_validation.py in _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters, fit_params, return_train_score, return_parameters, return_n_test_samples, return_times, return_estimator, error_score)
    526             estimator.fit(X_train, **fit_params)
    527         else:
--> 528             estimator.fit(X_train, y_train, **fit_params)
    529 
    530     except Exception as e:

D:\Miniconda3\lib\site-packages\lightgbm\sklearn.py in fit(self, X, y, sample_weight, init_score, eval_set, eval_names, eval_sample_weight, eval_class_weight, eval_init_score, eval_metric, early_stopping_rounds, verbose, feature_name, categorical_feature, callbacks)
    793                     eval_set[i] = (valid_x, _y)
    794                 else:
--> 795                     eval_set[i] = (valid_x, self._le.transform(valid_y))
    796 
    797         super(LGBMClassifier, self).fit(X, _y, sample_weight=sample_weight,

D:\Miniconda3\lib\site-packages\sklearn\preprocessing\label.py in transform(self, y)
    255             return np.array([])
    256 
--> 257         _, y = _encode(y, uniques=self.classes_, encode=True)
    258         return y
    259 

D:\Miniconda3\lib\site-packages\sklearn\preprocessing\label.py in _encode(values, uniques, encode)
    108         return _encode_python(values, uniques, encode)
    109     else:
--> 110         return _encode_numpy(values, uniques, encode)
    111 
    112 

D:\Miniconda3\lib\site-packages\sklearn\preprocessing\label.py in _encode_numpy(values, uniques, encode)
     51         if diff:
     52             raise ValueError("y contains previously unseen labels: %s"
---> 53                              % str(diff))
     54         encoded = np.searchsorted(uniques, values)
     55         return uniques, encoded

ValueError: y contains previously unseen labels: [0, 1, 2, 3]

The problem is the list's mutability:

if eval_set is not None:
    if isinstance(eval_set, tuple):
        eval_set = [eval_set]
    for i, (valid_x, valid_y) in enumerate(eval_set):
        if valid_x is X and valid_y is y:
            eval_set[i] = (valid_x, _y)
        else:
            eval_set[i] = (valid_x, self._le.transform(valid_y))

I'll try to fix it.

@cia05rf If you are in a hurry, you can try the following workaround:

,'eval_set':[(X_valid,y_valid)] -> ,'eval_set':(X_valid,y_valid)
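
To illustrate why that helps, here is a rough sketch of my own (mimicking the quoted branch, not LightGBM's actual code): RandomizedSearchCV hands the same fit_params dict to every internal fit() call, so a list-valued eval_set is shared and mutated in place, while a tuple gets wrapped in a fresh list on each call and the caller's object is left alone.

import numpy as np
from sklearn.preprocessing import LabelEncoder

def fit_once(eval_set, le):
    # mimics the quoted branch of LGBMClassifier.fit()
    if isinstance(eval_set, tuple):
        eval_set = [eval_set]  # fresh list: the caller's object stays untouched
    for i, (valid_x, valid_y) in enumerate(eval_set):
        eval_set[i] = (valid_x, le.transform(valid_y))  # with a list, this mutates fit_params in place

le = LabelEncoder().fit(np.array(['buy', 'hold', 'sell'], dtype=object))
X_valid = np.zeros((3, 2))
y_valid = np.array(['buy', 'hold', 'sell'], dtype=object)

# list form (as in this report): the first fit replaces y_valid with encoded ints,
# so the second fit feeds integers to an encoder fitted on strings
fit_params = {'eval_set': [(X_valid, y_valid)]}
fit_once(fit_params['eval_set'], le)
try:
    fit_once(fit_params['eval_set'], le)
except (TypeError, ValueError) as e:
    print('second fit failed:', e)

# tuple form (the workaround): fit_params is never mutated, every fit sees the original labels
fit_params = {'eval_set': (X_valid, y_valid)}
fit_once(fit_params['eval_set'], le)
fit_once(fit_params['eval_set'], le)  # still fine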

StrikerRUS self-assigned this Dec 5, 2019
lock bot locked as resolved and limited conversation to collaborators Mar 10, 2020