diff --git a/examples/svm/svm.py b/examples/svm/svm.py index 68129bb..22a90cd 100644 --- a/examples/svm/svm.py +++ b/examples/svm/svm.py @@ -60,7 +60,7 @@ def main(params): """ # Extract the kernel name and parameters. This is just a short expression # to get the only dictionary entry, which should have our hyperparameters. - kernel = params[list(params.keys())[0]] + kernel_params = list(params.values())[0] # Load the training and test sets. (X_train, y_train), (X_test, y_test) = load_data() @@ -68,11 +68,20 @@ def main(params): X_train = X_train.astype(np.float32) / 255.0 X_test = X_test.astype(np.float32) / 255.0 + # Set up the SVM with its parameterized kernel. The long form of instantiation + # is done here to show what `kernel_params` looks like internally. + # This can be shortened to `svc = SVC(**kernel_params)` + svc = SVC(kernel=kernel_params['kernel'], + C=kernel_params['C'], + gamma=kernel_params['gamma'] if 'gamma' in kernel_params else None, + coef0=kernel_params['coef0'] if 'coef0' in kernel_params else None, + degree=kernel_params['degree'] if 'degree' in kernel_params else None) + # Set up parallel training across as many cores as are available on the # worker. s = OneVsRestClassifier( BaggingClassifier( - SVC(**kernel), + svc, n_estimators=10, max_samples=0.1, n_jobs=-1))