In [1]:
import numpy as np
from si.io.csv_file import read_csv
from si.models.logistic_regression import LogisticRegression
from si.model_selection.cross_validation import k_fold_cross_validation
from si.model_selection.randomized_search import randomized_search_cv

In [2]:
# datasets
breast_bin_dataset = read_csv('../datasets/breast_bin/breast-bin.csv', features=False, label=True)

In [3]:
# cross validation
lg = LogisticRegression()
scores = k_fold_cross_validation(lg, breast_bin_dataset, cv=5)
scores

[0.9568345323741008,
 0.9712230215827338,
 0.9424460431654677,
 0.9928057553956835,
 0.9712230215827338]

In [4]:
l2_ = np.linspace(1, 10, 10)
alpha_ = np.linspace(0.001, 0.0001, 100)
max_iter_ = np.linspace(1000, 2000, 200)

In [5]:
# random search cv

lg = LogisticRegression()

# parameter grid
parameter_grid = {
    'l2_penalty': l2_,
    'alpha': alpha_,
    'max_iter': max_iter_
}

# cross validate the model
scores = randomized_search_cv(lg,
                        breast_bin_dataset,
                        hyperparameter_grid=parameter_grid,
                        cv=3,
                        n_iter=10)

scores

{'scores': [0.9669540229885057,
  0.9669540229885057,
  0.9669540229885057,
  0.9669540229885057,
  0.9669540229885057,
  0.9669540229885056,
  0.9669540229885057,
  0.9669540229885057,
  0.9669540229885057,
  0.9669540229885057],
 'hyperparameters': [{'l2_penalty': 8.0,
   'alpha': 0.0002181818181818182,
   'max_iter': 1286.43216080402},
  {'l2_penalty': 10.0,
   'alpha': 0.0008454545454545455,
   'max_iter': 1391.9597989949748},
  {'l2_penalty': 8.0,
   'alpha': 0.00012727272727272728,
   'max_iter': 1261.3065326633166},
  {'l2_penalty': 9.0,
   'alpha': 0.0005181818181818182,
   'max_iter': 1075.3768844221106},
  {'l2_penalty': 4.0,
   'alpha': 0.0007818181818181818,
   'max_iter': 1738.6934673366836},
  {'l2_penalty': 4.0,
   'alpha': 0.0005727272727272727,
   'max_iter': 1336.6834170854272},
  {'l2_penalty': 6.0,
   'alpha': 0.00030909090909090914,
   'max_iter': 1834.1708542713568},
  {'l2_penalty': 6.0,
   'alpha': 0.00032727272727272726,
   'max_iter': 1854.2713567839196},
  {'