Skip to content

Commit

Permalink
Update to make parameterizing the SVC more explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffkinnison committed Aug 6, 2018
1 parent a5d79b3 commit 660d3b2
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions examples/svm/svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,28 @@ def main(params):
"""
# Extract the kernel name and parameters. This is just a short expression
# to get the only dictionary entry, which should have our hyperparameters.
kernel = params[list(params.keys())[0]]
kernel_params = list(params.values())[0]

# Load the training and test sets.
(X_train, y_train), (X_test, y_test) = load_data()

X_train = X_train.astype(np.float32) / 255.0
X_test = X_test.astype(np.float32) / 255.0

# Set up the SVM with its parameterized kernel. The long form of instantiation
# is done here to show what `kernel_params` looks like internally.
# This can be shortened to `svc = SVC(**kernel_params)`
svc = SVC(kernel=kernel_params['kernel'],
C=kernel_params['C'],
gamma=kernel_params['gamma'] if 'gamma' in kernel_params else None,
coef0=kernel_params['coef0'] if 'coef0' in kernel_params else None,
degree=kernel_params['degree'] if 'degree' in kernel_params else None)

# Set up parallel training across as many cores as are available on the
# worker.
s = OneVsRestClassifier(
BaggingClassifier(
SVC(**kernel),
svc,
n_estimators=10,
max_samples=0.1,
n_jobs=-1))
Expand Down

0 comments on commit 660d3b2

Please sign in to comment.