In [1]:
import numpy as np
import lightgbm as lgb
from wideboost.wrappers import wlgb

import tensorflow_datasets as tfds
from matplotlib import pyplot as plt

(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

for i in ds_train.batch(60000):
    a = i
    break
    
for i in ds_test.batch(60000):
    b = i
    break

In [2]:
xtrain = a[0].numpy().reshape([-1,28*28])
ytrain = a[1].numpy()

xtest = b[0].numpy().reshape([-1,28*28])
ytest = b[1].numpy()

#dtrain = xgb.DMatrix(xtrain,label=ytrain)
#dtest = xgb.DMatrix(xtest,label=ytest)

train_data = lgb.Dataset(xtrain, label=ytrain)
test_data = lgb.Dataset(xtest, label=ytest)

In [3]:
from hyperopt import fmin, tpe, hp, STATUS_OK, space_eval

best_val = 1.0

def objective(param):
    global best_val
    #watchlist = [(dtrain,'train'),(dtest,'test')]
    ed1_results = dict()
    print(param)
    param['num_leaves'] = round(param['num_leaves']+1)
    param['min_data_in_leaf'] = round(param['min_data_in_leaf'])
    wbst = wlgb.train(param,
                      train_data,
                      num_boost_round=20,
                      valid_sets=test_data,
                      evals_result=ed1_results)
    output = min(ed1_results['valid_0']['multi_error'])
    
    if output < best_val:
        print("NEW BEST VALUE!")
        best_val = output
    
    return {'loss': output, 'status': STATUS_OK }

spc = {
    'btype': hp.choice('btype',['R','I','Rn','In']),
    'extra_dims': hp.choice('extra_dims',range(16)),
    'objective': hp.choice('objective',['multiclass']),
    'metric':hp.choice('metric',['multi_error']),
    'num_class':hp.choice('num_class',[10]),
    'learning_rate': hp.loguniform('learning_rate', -7, 0),
    'num_leaves' : hp.qloguniform('num_leaves', 0, 7, 1),
    'feature_fraction': hp.uniform('feature_fraction', 0.5, 1),
    'bagging_fraction': hp.uniform('bagging_fraction', 0.5, 1),
    'min_data_in_leaf': hp.qloguniform('min_data_in_leaf', 0, 6, 1),
    'min_sum_hessian_in_leaf': hp.loguniform('min_sum_hessian_in_leaf', -16, 5),
    'lambda_l1': hp.choice('lambda_l1', [0, hp.loguniform('lambda_l1_positive', -16, 2)]),
    'lambda_l2': hp.choice('lambda_l2', [0, hp.loguniform('lambda_l2_positive', -16, 2)])
}


best = fmin(objective,
    space=spc,
    algo=tpe.suggest,
    max_evals=100)

{'bagging_fraction': 0.6236568327389121, 'btype': 'Rn', 'extra_dims': 9, 'feature_fraction': 0.8664406004424731, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.023147449563030463, 'metric': 'multi_error', 'min_data_in_leaf': 62.0, 'min_sum_hessian_in_leaf': 7.086725693731613e-07, 'num_class': 10, 'num_leaves': 21.0, 'objective': 'multiclass'}
Overwriting param `num_class`                          
Overwriting param `objective` while setting `fobj` in train.
Moving param `metric` to an feval.                     
[1]	valid_0's multi_error: 0.1351                      
[2]	valid_0's multi_error: 0.1201                      
[3]	valid_0's multi_error: 0.1116                      
[4]	valid_0's multi_error: 0.1078                      
[5]	valid_0's multi_error: 0.1068                      
[6]	valid_0's multi_error: 0.1051                      
[7]	valid_0's multi_error: 0.1031                      
[8]	valid_0's multi_error: 0.1012                      
[9]	valid_0's multi_error: 0.0

Overwriting param `objective` while setting `fobj` in train.         
Moving param `metric` to an feval.                                   
[1]	valid_0's multi_error: 0.0647                                    
[2]	valid_0's multi_error: 0.0486                                    
[3]	valid_0's multi_error: 0.0405                                    
[4]	valid_0's multi_error: 0.0353                                    
[5]	valid_0's multi_error: 0.0327                                    
[6]	valid_0's multi_error: 0.032                                     
[7]	valid_0's multi_error: 0.0302                                    
[8]	valid_0's multi_error: 0.029                                     
[9]	valid_0's multi_error: 0.0287                                    
[10]	valid_0's multi_error: 0.0279                                   
[11]	valid_0's multi_error: 0.0275                                   
[12]	valid_0's multi_error: 0.0272                                   
[13]	valid_0's multi

[1]	valid_0's multi_error: 0.1225                                       
[2]	valid_0's multi_error: 0.3712                                       
[3]	valid_0's multi_error: 0.3282                                       
[4]	valid_0's multi_error: 0.6758                                       
[5]	valid_0's multi_error: 0.74                                         
[6]	valid_0's multi_error: 0.8245                                       
[7]	valid_0's multi_error: 0.7126                                       
[8]	valid_0's multi_error: 0.6185                                       
[9]	valid_0's multi_error: 0.5664                                       
[10]	valid_0's multi_error: 0.5201                                      
[11]	valid_0's multi_error: 0.4917                                      
[12]	valid_0's multi_error: 0.4532                                      
[13]	valid_0's multi_error: 0.422                                       
[14]	valid_0's multi_error: 0.4099                 

[1]	valid_0's multi_error: 0.5041                                       
[2]	valid_0's multi_error: 0.4386                                       
[3]	valid_0's multi_error: 0.4195                                       
[4]	valid_0's multi_error: 0.3846                                       
[5]	valid_0's multi_error: 0.372                                        
[6]	valid_0's multi_error: 0.3497                                       
[7]	valid_0's multi_error: 0.3311                                       
[8]	valid_0's multi_error: 0.3169                                       
[9]	valid_0's multi_error: 0.3099                                       
[10]	valid_0's multi_error: 0.3018                                      
[11]	valid_0's multi_error: 0.293                                       
[12]	valid_0's multi_error: 0.2871                                      
[13]	valid_0's multi_error: 0.2772                                      
[14]	valid_0's multi_error: 0.2665                 

[2]	valid_0's multi_error: 0.3081                                     
[3]	valid_0's multi_error: 0.2673                                     
[4]	valid_0's multi_error: 0.3138                                     
[5]	valid_0's multi_error: 0.3283                                     
[6]	valid_0's multi_error: 0.3769                                     
[7]	valid_0's multi_error: 0.5763                                     
[8]	valid_0's multi_error: 0.8408                                     
[9]	valid_0's multi_error: 0.8491                                     
[10]	valid_0's multi_error: 0.8491                                    
[11]	valid_0's multi_error: 0.8491                                    
[12]	valid_0's multi_error: 0.8491                                    
[13]	valid_0's multi_error: 0.8491                                    
[14]	valid_0's multi_error: 0.8491                                    
[15]	valid_0's multi_error: 0.8491                                    
[16]	v

[5]	valid_0's multi_error: 0.034                                      
[6]	valid_0's multi_error: 0.0334                                     
[7]	valid_0's multi_error: 0.0318                                     
[8]	valid_0's multi_error: 0.0315                                     
[9]	valid_0's multi_error: 0.031                                      
[10]	valid_0's multi_error: 0.0294                                    
[11]	valid_0's multi_error: 0.0288                                    
[12]	valid_0's multi_error: 0.0284                                    
[13]	valid_0's multi_error: 0.0281                                    
[14]	valid_0's multi_error: 0.0276                                    
[15]	valid_0's multi_error: 0.0268                                    
[16]	valid_0's multi_error: 0.027                                     
[17]	valid_0's multi_error: 0.0268                                    
[18]	valid_0's multi_error: 0.0261                                    
[19]	v

[5]	valid_0's multi_error: 0.1737                                     
[6]	valid_0's multi_error: 0.1625                                     
[7]	valid_0's multi_error: 0.1579                                     
[8]	valid_0's multi_error: 0.1533                                     
[9]	valid_0's multi_error: 0.1488                                     
[10]	valid_0's multi_error: 0.1457                                    
[11]	valid_0's multi_error: 0.1419                                    
[12]	valid_0's multi_error: 0.1365                                    
[13]	valid_0's multi_error: 0.1367                                    
[14]	valid_0's multi_error: 0.1354                                    
[15]	valid_0's multi_error: 0.1329                                    
[16]	valid_0's multi_error: 0.1329                                    
[17]	valid_0's multi_error: 0.1322                                    
[18]	valid_0's multi_error: 0.1306                                    
[19]	v

[8]	valid_0's multi_error: 0.1049                                     
[9]	valid_0's multi_error: 0.1037                                     
[10]	valid_0's multi_error: 0.1019                                    
[11]	valid_0's multi_error: 0.1004                                    
[12]	valid_0's multi_error: 0.098                                     
[13]	valid_0's multi_error: 0.097                                     
[14]	valid_0's multi_error: 0.0947                                    
[15]	valid_0's multi_error: 0.0935                                    
[16]	valid_0's multi_error: 0.0906                                    
[17]	valid_0's multi_error: 0.0895                                    
[18]	valid_0's multi_error: 0.0879                                    
[19]	valid_0's multi_error: 0.0858                                    
[20]	valid_0's multi_error: 0.0847                                    
{'bagging_fraction': 0.8583068608321964, 'btype': 'Rn', 'extra_dims': 15, 'fe

[9]	valid_0's multi_error: 0.0257                                     
[10]	valid_0's multi_error: 0.0243                                    
[11]	valid_0's multi_error: 0.0233                                    
[12]	valid_0's multi_error: 0.0225                                    
[13]	valid_0's multi_error: 0.0218                                    
[14]	valid_0's multi_error: 0.0218                                    
[15]	valid_0's multi_error: 0.0213                                    
[16]	valid_0's multi_error: 0.0208                                    
[17]	valid_0's multi_error: 0.0204                                    
[18]	valid_0's multi_error: 0.0203                                    
[19]	valid_0's multi_error: 0.02                                      
[20]	valid_0's multi_error: 0.0195                                    
NEW BEST VALUE!                                                       
{'bagging_fraction': 0.8565198438111515, 'btype': 'R', 'extra_dims': 10, 'fea

[9]	valid_0's multi_error: 0.103                                        
[10]	valid_0's multi_error: 0.101                                       
[11]	valid_0's multi_error: 0.0969                                      
[12]	valid_0's multi_error: 0.0929                                      
[13]	valid_0's multi_error: 0.0916                                      
[14]	valid_0's multi_error: 0.0887                                      
[15]	valid_0's multi_error: 0.0866                                      
[16]	valid_0's multi_error: 0.0837                                      
[17]	valid_0's multi_error: 0.0808                                      
[18]	valid_0's multi_error: 0.0791                                      
[19]	valid_0's multi_error: 0.0757                                      
[20]	valid_0's multi_error: 0.0737                                      
{'bagging_fraction': 0.6680407383897912, 'btype': 'R', 'extra_dims': 12, 'feature_fraction': 0.5818809655955655, 'lambda_l1'

[9]	valid_0's multi_error: 0.3667                                     
[10]	valid_0's multi_error: 0.602                                     
[11]	valid_0's multi_error: 0.5319                                    
[12]	valid_0's multi_error: 0.6131                                    
[13]	valid_0's multi_error: 0.6831                                    
[14]	valid_0's multi_error: 0.7671                                    
[15]	valid_0's multi_error: 0.7671                                    
[16]	valid_0's multi_error: 0.7671                                    
[17]	valid_0's multi_error: 0.7671                                    
[18]	valid_0's multi_error: 0.7671                                    
[19]	valid_0's multi_error: 0.7671                                    
[20]	valid_0's multi_error: 0.7671                                    
{'bagging_fraction': 0.6578961321393686, 'btype': 'R', 'extra_dims': 2, 'feature_fraction': 0.5326450506595152, 'lambda_l1': 9.558591020215512e-05, 

[11]	valid_0's multi_error: 0.0366                                    
[12]	valid_0's multi_error: 0.034                                     
[13]	valid_0's multi_error: 0.0323                                    
[14]	valid_0's multi_error: 0.0311                                    
[15]	valid_0's multi_error: 0.0296                                    
[16]	valid_0's multi_error: 0.0283                                    
[17]	valid_0's multi_error: 0.0276                                    
[18]	valid_0's multi_error: 0.0263                                    
[19]	valid_0's multi_error: 0.0256                                    
[20]	valid_0's multi_error: 0.0246                                    
{'bagging_fraction': 0.5497477173844217, 'btype': 'R', 'extra_dims': 9, 'feature_fraction': 0.6722778725103158, 'lambda_l1': 0, 'lambda_l2': 0.07447596284635247, 'learning_rate': 0.5604819541294173, 'metric': 'multi_error', 'min_data_in_leaf': 26.0, 'min_sum_hessian_in_leaf': 0.00860778764

[13]	valid_0's multi_error: 0.163                                     
[14]	valid_0's multi_error: 0.2153                                    
[15]	valid_0's multi_error: 0.1923                                    
[16]	valid_0's multi_error: 0.2189                                    
[17]	valid_0's multi_error: 0.1971                                    
[18]	valid_0's multi_error: 0.274                                     
[19]	valid_0's multi_error: 0.24                                      
[20]	valid_0's multi_error: 0.3091                                    
{'bagging_fraction': 0.602792171745752, 'btype': 'I', 'extra_dims': 6, 'feature_fraction': 0.5844938057001452, 'lambda_l1': 0.0001448085810537498, 'lambda_l2': 0, 'learning_rate': 0.002330260500533525, 'metric': 'multi_error', 'min_data_in_leaf': 121.0, 'min_sum_hessian_in_leaf': 0.00020032403101572378, 'num_class': 10, 'num_leaves': 1080.0, 'objective': 'multiclass'}
Overwriting param `num_class`                                

[15]	valid_0's multi_error: 0.0255                                    
[16]	valid_0's multi_error: 0.0262                                    
[17]	valid_0's multi_error: 0.0255                                    
[18]	valid_0's multi_error: 0.026                                     
[19]	valid_0's multi_error: 0.0256                                    
[20]	valid_0's multi_error: 0.0259                                    
{'bagging_fraction': 0.6527890407678644, 'btype': 'In', 'extra_dims': 4, 'feature_fraction': 0.6414071334329047, 'lambda_l1': 0, 'lambda_l2': 0.03947391221815463, 'learning_rate': 0.5281014217380049, 'metric': 'multi_error', 'min_data_in_leaf': 71.0, 'min_sum_hessian_in_leaf': 0.009984783020349273, 'num_class': 10, 'num_leaves': 521.0, 'objective': 'multiclass'}
Overwriting param `num_class`                                         
Overwriting param `objective` while setting `fobj` in train.          
Moving param `metric` to an feval.                                 

[17]	valid_0's multi_error: 0.4989                                    
[18]	valid_0's multi_error: 0.4847                                    
[19]	valid_0's multi_error: 0.4771                                    
[20]	valid_0's multi_error: 0.468                                     
{'bagging_fraction': 0.688912115001042, 'btype': 'R', 'extra_dims': 13, 'feature_fraction': 0.7013468054325842, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.23869730185028282, 'metric': 'multi_error', 'min_data_in_leaf': 23.0, 'min_sum_hessian_in_leaf': 0.0955766243017595, 'num_class': 10, 'num_leaves': 17.0, 'objective': 'multiclass'}
Overwriting param `num_class`                                         
Overwriting param `objective` while setting `fobj` in train.          
Moving param `metric` to an feval.                                    
[1]	valid_0's multi_error: 0.1525                                     
[2]	valid_0's multi_error: 0.0889                                     
[3]	valid_0's mul

[20]	valid_0's multi_error: 0.0202                                    
{'bagging_fraction': 0.6412709353697169, 'btype': 'R', 'extra_dims': 3, 'feature_fraction': 0.6021771088404879, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.10187305753924562, 'metric': 'multi_error', 'min_data_in_leaf': 346.0, 'min_sum_hessian_in_leaf': 3.07413535598088e-05, 'num_class': 10, 'num_leaves': 2.0, 'objective': 'multiclass'}
Overwriting param `num_class`                                         
Overwriting param `objective` while setting `fobj` in train.          
Moving param `metric` to an feval.                                    
[1]	valid_0's multi_error: 0.4231                                     
[2]	valid_0's multi_error: 0.3752                                     
[3]	valid_0's multi_error: 0.3413                                     
[4]	valid_0's multi_error: 0.3188                                     
[5]	valid_0's multi_error: 0.303                                      
[6]	valid_0's m

Overwriting param `num_class`                                         
Overwriting param `objective` while setting `fobj` in train.          
Moving param `metric` to an feval.                                    
[1]	valid_0's multi_error: 0.0925                                     
[2]	valid_0's multi_error: 0.0627                                     
[3]	valid_0's multi_error: 0.0502                                     
[4]	valid_0's multi_error: 0.0429                                     
[5]	valid_0's multi_error: 0.0388                                     
[6]	valid_0's multi_error: 0.0343                                     
[7]	valid_0's multi_error: 0.0317                                     
[8]	valid_0's multi_error: 0.0302                                     
[9]	valid_0's multi_error: 0.0293                                     
[10]	valid_0's multi_error: 0.0282                                    
[11]	valid_0's multi_error: 0.0267                                    
[12]	v

Moving param `metric` to an feval.                                    
[1]	valid_0's multi_error: 0.0908                                     
[2]	valid_0's multi_error: 0.058                                      
[3]	valid_0's multi_error: 0.0466                                     
[4]	valid_0's multi_error: 0.0718                                     
[5]	valid_0's multi_error: 0.1084                                     
[6]	valid_0's multi_error: 0.3262                                     
[7]	valid_0's multi_error: 0.2887                                     
[8]	valid_0's multi_error: 0.438                                      
[9]	valid_0's multi_error: 0.4225                                     
[10]	valid_0's multi_error: 0.6152                                    
[11]	valid_0's multi_error: 0.6726                                    
[12]	valid_0's multi_error: 0.8072                                    
[13]	valid_0's multi_error: 0.8301                                    
[14]	v

[2]	valid_0's multi_error: 0.0635                                     
[3]	valid_0's multi_error: 0.0539                                     
[4]	valid_0's multi_error: 0.0475                                     
[5]	valid_0's multi_error: 0.0429                                     
[6]	valid_0's multi_error: 0.0387                                     
[7]	valid_0's multi_error: 0.0348                                     
[8]	valid_0's multi_error: 0.0324                                     
[9]	valid_0's multi_error: 0.0297                                     
[10]	valid_0's multi_error: 0.0283                                    
[11]	valid_0's multi_error: 0.0271                                    
[12]	valid_0's multi_error: 0.025                                     
[13]	valid_0's multi_error: 0.0248                                    
[14]	valid_0's multi_error: 0.0237                                    
[15]	valid_0's multi_error: 0.0226                                    
[16]	v

[5]	valid_0's multi_error: 0.0347                                     
[6]	valid_0's multi_error: 0.0312                                     
[7]	valid_0's multi_error: 0.0287                                     
[8]	valid_0's multi_error: 0.0279                                     
[9]	valid_0's multi_error: 0.027                                      
[10]	valid_0's multi_error: 0.0264                                    
[11]	valid_0's multi_error: 0.025                                     
[12]	valid_0's multi_error: 0.0235                                    
[13]	valid_0's multi_error: 0.0217                                    
[14]	valid_0's multi_error: 0.0228                                    
[15]	valid_0's multi_error: 0.0218                                    
[16]	valid_0's multi_error: 0.0212                                    
[17]	valid_0's multi_error: 0.0204                                    
[18]	valid_0's multi_error: 0.0215                                    
[19]	v

[8]	valid_0's multi_error: 0.1066                                     
[9]	valid_0's multi_error: 0.0993                                     
[10]	valid_0's multi_error: 0.0943                                    
[11]	valid_0's multi_error: 0.0888                                    
[12]	valid_0's multi_error: 0.084                                     
[13]	valid_0's multi_error: 0.0815                                    
[14]	valid_0's multi_error: 0.0765                                    
[15]	valid_0's multi_error: 0.0729                                    
[16]	valid_0's multi_error: 0.0722                                    
[17]	valid_0's multi_error: 0.0693                                    
[18]	valid_0's multi_error: 0.0643                                    
[19]	valid_0's multi_error: 0.0623                                    
[20]	valid_0's multi_error: 0.0611                                    
{'bagging_fraction': 0.8438605806006528, 'btype': 'R', 'extra_dims': 0, 'feat

[10]	valid_0's multi_error: 0.2415                                      
[11]	valid_0's multi_error: 0.2335                                      
[12]	valid_0's multi_error: 0.2264                                      
[13]	valid_0's multi_error: 0.2182                                      
[14]	valid_0's multi_error: 0.2114                                      
[15]	valid_0's multi_error: 0.2069                                      
[16]	valid_0's multi_error: 0.2003                                      
[17]	valid_0's multi_error: 0.195                                       
[18]	valid_0's multi_error: 0.1891                                      
[19]	valid_0's multi_error: 0.1858                                      
[20]	valid_0's multi_error: 0.1812                                      
{'bagging_fraction': 0.9766052435268868, 'btype': 'R', 'extra_dims': 5, 'feature_fraction': 0.8705183167118361, 'lambda_l1': 1.8110573726078843e-05, 'lambda_l2': 0, 'learning_rate': 0.3412378660389

[10]	valid_0's multi_error: 0.0665                                      
[11]	valid_0's multi_error: 0.0659                                      
[12]	valid_0's multi_error: 0.0641                                      
[13]	valid_0's multi_error: 0.063                                       
[14]	valid_0's multi_error: 0.0619                                      
[15]	valid_0's multi_error: 0.0604                                      
[16]	valid_0's multi_error: 0.0595                                      
[17]	valid_0's multi_error: 0.0585                                      
[18]	valid_0's multi_error: 0.0578                                      
[19]	valid_0's multi_error: 0.0562                                      
[20]	valid_0's multi_error: 0.0551                                      
{'bagging_fraction': 0.5311971397755882, 'btype': 'I', 'extra_dims': 9, 'feature_fraction': 0.7433962159808519, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.7398140961211418, 'metric': 'multi

[10]	valid_0's multi_error: 0.0921                                      
[11]	valid_0's multi_error: 0.0911                                      
[12]	valid_0's multi_error: 0.0859                                      
[13]	valid_0's multi_error: 0.0832                                      
[14]	valid_0's multi_error: 0.0814                                      
[15]	valid_0's multi_error: 0.0805                                      
[16]	valid_0's multi_error: 0.0782                                      
[17]	valid_0's multi_error: 0.0754                                      
[18]	valid_0's multi_error: 0.074                                       
[19]	valid_0's multi_error: 0.0727                                      
[20]	valid_0's multi_error: 0.0711                                      
{'bagging_fraction': 0.6156192855209107, 'btype': 'R', 'extra_dims': 6, 'feature_fraction': 0.7204662864410877, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.009591263583834709, 'metric': 'mul

[10]	valid_0's multi_error: 0.0345                                      
[11]	valid_0's multi_error: 0.0325                                      
[12]	valid_0's multi_error: 0.0313                                      
[13]	valid_0's multi_error: 0.0299                                      
[14]	valid_0's multi_error: 0.0292                                      
[15]	valid_0's multi_error: 0.027                                       
[16]	valid_0's multi_error: 0.0267                                      
[17]	valid_0's multi_error: 0.0253                                      
[18]	valid_0's multi_error: 0.0252                                      
[19]	valid_0's multi_error: 0.0244                                      
[20]	valid_0's multi_error: 0.024                                       
{'bagging_fraction': 0.7575573579439379, 'btype': 'R', 'extra_dims': 11, 'feature_fraction': 0.8163656408228182, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.2661475699301378, 'metric': 'mult

In [4]:
print(best_val)
print(space_eval(spc, best))

0.0179
{'bagging_fraction': 0.7801112197137942, 'btype': 'R', 'extra_dims': 13, 'feature_fraction': 0.6669223988802793, 'lambda_l1': 0, 'lambda_l2': 0, 'learning_rate': 0.154411820991592, 'metric': 'multi_error', 'min_data_in_leaf': 50.0, 'min_sum_hessian_in_leaf': 9.213964819654149e-05, 'num_class': 10, 'num_leaves': 72.0, 'objective': 'multiclass'}
