Skorch + HyperbandSearchCV Example #664

Open
ToddMorrill opened this issue May 13, 2020 · 13 comments

@ToddMorrill

Can you post a working example that uses Skorch and HyperbandSearchCV? I haven't been able to find an actual working example.

The biggest challenge I've faced so far is determining the batch size being fed to the model. It's unclear whether that is the chunk size and, beyond that, how Skorch's batching interacts with it.

If I run search.fit(X, y) with NumPy arrays, the chunk size is massive and the search is very slow. If I chunk X and y as Dask arrays, I get the following error:
ValueError: With n_samples=1, test_size=0.058823529411764705 and train_size=0.9411764705882353, the resulting train set will be empty. Adjust any of the aforementioned parameters.

@TomAugspurger
Member

TomAugspurger commented May 13, 2020

I know that @stsievert has a full example somewhere, but I wasn't able to easily find it in his repositories. Hopefully he can find the link.

@ToddMorrill you're passing NumPy arrays to HyperbandSearchCV? That doesn't sound right. It expects Dask arrays.

Can you provide a full example and traceback that give a ValueError when passing a Dask array?

@stsievert
Member

a full example somewhere

It's the "image-denoising" model at https://github.com/stsievert/dask-hyperband-comparison/. This is a PyTorch model, defined in image-denoising/autoencoder.py.

how to determine the batch size being fed to the model

I also tune the batch size in my example. PyTorch uses the batch size during optimization: each mini-batch gives an approximation of the loss function's gradient. The relevant entry in my hyperparameter dictionary is:

params = {
    ...
    'batch_size': [32, 64, 128, 256, 512],
}

The notebook that tunes the hyperparameters is at image-denoising/Run.ipynb.
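To make the chunk/batch relationship from the original question concrete: as far as I understand it, each partial_fit call made by the search receives one chunk of the Dask array, and skorch's partial_fit then runs max_epochs passes over that chunk in mini-batches of batch_size. The sketch below counts optimizer steps per partial_fit call under those assumptions. It also assumes train_split=None and that the DataLoader keeps the final short batch (drop_last=False). The helper name is hypothetical and is not a skorch or dask-ml API.

```python
import math


def steps_per_partial_fit(chunk_rows: int, batch_size: int, max_epochs: int = 1) -> int:
    """Hypothetical helper: optimizer steps taken by one partial_fit call.

    Assumes the whole chunk is used for training (train_split=None) and the
    last, possibly smaller, mini-batch is kept rather than dropped.
    """
    batches_per_epoch = math.ceil(chunk_rows / batch_size)
    return max_epochs * batches_per_epoch


# Example: a 250-row chunk with two candidate batch sizes
for bs in (16, 32):
    print(bs, steps_per_partial_fit(250, bs))  # 16 -> 16 steps, 32 -> 8 steps
```

So the chunk size controls how much data one partial_fit call sees, and batch_size only controls how that chunk is sliced into gradient steps.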

That doesn't sound right. It expects dask arrays.

I think Dask arrays should be passed too, but it looks like NumPy arrays are handled as well. The search code scatters non-Dask test data to the workers:

if isinstance(X_test, da.Array):
    X_test = client.compute(X_test)
else:
    X_test = yield client.scatter(X_test)

@ToddMorrill
Author

Thanks for the rapid response @TomAugspurger, @stsievert. I actually got unlucky with my choice of chunk size: the last chunk had a size of 1 (i.e. len(X) % chunks == 1), which clearly doesn't leave enough data to split into train and test sets. Changing chunks resolved that issue.
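A quick way to catch this before fitting is to compute the block sizes Dask will produce. With a fixed integer chunk size, every block is full except a remainder block of len(X) % chunk_size rows. The plain-Python sketch below uses hypothetical helper names. It shows that calculation and one way to avoid a tiny remainder: spread the rows over n nearly equal blocks. Those explicit sizes could then be passed to Dask, for example as chunks=(tuple(sizes), (n_features,)) in da.from_array.

```python
def block_sizes(n_rows: int, chunk_size: int) -> list:
    """Block sizes along axis 0 for a fixed integer chunk size (last block may be short)."""
    sizes = [chunk_size] * (n_rows // chunk_size)
    if n_rows % chunk_size:
        sizes.append(n_rows % chunk_size)
    return sizes


def even_block_sizes(n_rows: int, n_blocks: int) -> list:
    """Split n_rows into n_blocks sizes that differ by at most one row."""
    base, extra = divmod(n_rows, n_blocks)
    return [base + 1 if i < extra else base for i in range(n_blocks)]


print(block_sizes(3217, 321))       # ten blocks of 321, then a 7-row remainder
print(even_block_sizes(3217, 10))   # seven blocks of 322, three of 321
```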

I'm now facing a new error
ValueError: n_splits=5 cannot be greater than the number of members in each class.

I suspect this has to do with the number of cross-validation splits and/or the chunk size, but cv doesn't appear to be a parameter of HyperbandSearchCV. Here's the log output before it crashes. How are 6 and 291 computed?

[CV, bracket=1] creating 3 models
[CV, bracket=0] creating 2 models
[CV, bracket=1] For training there are between 6 and 291 examples in each chunk
[CV, bracket=0] For training there are between 6 and 291 examples in each chunk
[CV, bracket=0] validation score of 0.8671 received after 1 partial_fit calls
[CV, bracket=1] validation score of 0.9203 received after 1 partial_fit calls
[CV, bracket=1] validation score of 0.9435 received after 2 partial_fit calls
[CV, bracket=1] validation score of 0.9668 received after 6 partial_fit calls

In my case len(X) == 3217, chunk_shape == (321, 413) (i.e. the first dimension is the number of examples and the second is the number of features), and batch_size: [16, 32] in my parameter dictionary.

How do I choose chunks/batches to avoid these issues? Can I lower the number of cross-validation splits from 5 to 3?
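The 6 and 291 in the log can be reproduced with a little arithmetic. This rests on an assumption inferred from the earlier error message, where test_size=0.0588... = 1/17 for data in 17 blocks. By default the search seems to hold out a fraction 1/n_blocks of each block and round the held-out count up. With 3217 rows in blocks of 321 there are 11 blocks: ten of 321 and one of 7. This is a sketch of my reading of the behavior, not dask-ml's code, and the helper name is hypothetical.

```python
import math


def train_rows_per_block(block_sizes, test_size=None):
    """Hypothetical sketch: training rows left in each block after a blockwise split.

    Assumes the test fraction defaults to 1 / n_blocks and that the number of
    test rows per block is rounded up (so the training count is rounded down).
    """
    if test_size is None:
        test_size = 1 / len(block_sizes)
    train = []
    for n in block_sizes:
        n_test = math.ceil(test_size * n)
        train.append(n - n_test)
    return train


blocks = [321] * 10 + [7]  # 3217 rows with chunk size 321
train = train_rows_per_block(blocks)
print(min(train), max(train))  # 6 291
```

Under the same assumption, a 1-row block among 17 blocks keeps 0 training rows, which matches the "resulting train set will be empty" error from the first comment.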

@stsievert
Member

I'm now facing a new error

What's the traceback for the ValueError with n_splits? IncrementalSearchCV and HyperbandSearchCV purposefully don't perform any cross-validation.

@ToddMorrill
Author

Here's everything

$ search.fit(X_train, y_train)
[CV, bracket=1] creating 3 models
[CV, bracket=0] creating 2 models
[CV, bracket=1] For training there are between 6 and 291 examples in each chunk
[CV, bracket=0] For training there are between 6 and 291 examples in each chunk
[CV, bracket=0] validation score of 0.8671 received after 1 partial_fit calls
[CV, bracket=1] validation score of 0.9203 received after 1 partial_fit calls
[CV, bracket=1] validation score of 0.9435 received after 2 partial_fit calls
[CV, bracket=1] validation score of 0.9668 received after 6 partial_fit calls
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-116-36bb9df0ec09> in <module>
----> 1 search.fit(X_train, y_train, ) #, verbose=2

~/Documents/regtech/.venv/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in fit(self, X, y, **fit_params)
    671 
    672         with context:
--> 673             return default_client().sync(self._fit, X, y, **fit_params)
    674 
    675     @if_delegate_has_method(delegate=("best_estimator_", "estimator"))

~/Documents/regtech/.venv/lib/python3.7/site-packages/distributed/client.py in sync(self, func, asynchronous, callback_timeout, *args, **kwargs)
    778         else:
    779             return sync(
--> 780                 self.loop, func, *args, callback_timeout=callback_timeout, **kwargs
    781             )
    782 

~/Documents/regtech/.venv/lib/python3.7/site-packages/distributed/utils.py in sync(loop, func, callback_timeout, *args, **kwargs)
    345     if error[0]:
    346         typ, exc, tb = error[0]
--> 347         raise exc.with_traceback(tb)
    348     else:
    349         return result[0]

~/Documents/regtech/.venv/lib/python3.7/site-packages/distributed/utils.py in f()
    329             if callback_timeout is not None:
    330                 future = asyncio.wait_for(future, callback_timeout)
--> 331             result[0] = yield future
    332         except Exception as exc:
    333             error[0] = sys.exc_info()

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

~/Documents/regtech/.venv/lib/python3.7/site-packages/dask_ml/model_selection/_hyperband.py in _fit(self, X, y, **fit_params)
    398 
    399         # _fit is run in parallel because it's also a tornado coroutine
--> 400         _SHAs = yield [SHAs[b]._fit(X, y, **fit_params) for b in _brackets_ids]
    401         SHAs = {b: SHA for b, SHA in zip(_brackets_ids, _SHAs)}
    402 

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in callback(fut)
    499             for f in children_futs:
    500                 try:
--> 501                     result_list.append(f.result())
    502                 except Exception as e:
    503                     if future.done():

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

~/Documents/regtech/.venv/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in _fit(self, X, y, **fit_params)
    623             random_state=self.random_state,
    624             verbose=self.verbose,
--> 625             prefix=self.prefix,
    626         )
    627         results = self._process_results(results)

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    740                     if exc_info is not None:
    741                         try:
--> 742                             yielded = self.gen.throw(*exc_info)  # type: ignore
    743                         finally:
    744                             # Break up a reference to itself

~/Documents/regtech/.venv/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in _fit(model, params, X_train, y_train, X_test, y_test, additional_calls, fit_params, scorer, random_state, verbose, prefix)
    231     # async for future, result in seq:
    232     for _i in itertools.count():
--> 233         metas = yield client.gather(new_scores)
    234 
    235         if log_delay and _i % int(log_delay) == 0:

~/Documents/regtech/.venv/lib/python3.7/site-packages/tornado/gen.py in run(self)
    733 
    734                     try:
--> 735                         value = future.result()
    736                     except Exception:
    737                         exc_info = sys.exc_info()

~/Documents/regtech/.venv/lib/python3.7/site-packages/distributed/client.py in _gather(self, futures, errors, direct, local_worker)
   1788                             exc = CancelledError(key)
   1789                         else:
-> 1790                             raise exception.with_traceback(traceback)
   1791                         raise exc
   1792                     if errors == "skip":

~/Documents/regtech/.venv/lib/python3.7/site-packages/dask_ml/model_selection/_incremental.py in _partial_fit()
     88         if len(X):
     89             model = deepcopy(model)
---> 90             model.partial_fit(X, y, **(fit_params or {}))
     91 
     92         meta = dict(meta)

~/Documents/regtech/.venv/lib/python3.7/site-packages/skorch/net.py in partial_fit()
    811         self.notify('on_train_begin', X=X, y=y)
    812         try:
--> 813             self.fit_loop(X, y, **fit_params)
    814         except KeyboardInterrupt:
    815             pass

~/Documents/regtech/.venv/lib/python3.7/site-packages/skorch/net.py in fit_loop()
    715 
    716         dataset_train, dataset_valid = self.get_split_datasets(
--> 717             X, y, **fit_params)
    718         on_epoch_kwargs = {
    719             'dataset_train': dataset_train,

~/Documents/regtech/.venv/lib/python3.7/site-packages/skorch/net.py in get_split_datasets()
   1199         if self.train_split:
   1200             dataset_train, dataset_valid = self.train_split(
-> 1201                 dataset, y, **fit_params)
   1202         else:
   1203             dataset_train, dataset_valid = dataset, None

~/Documents/regtech/.venv/lib/python3.7/site-packages/skorch/dataset.py in __call__()
    323             args = args + (to_numpy(y),)
    324 
--> 325         idx_train, idx_valid = next(iter(cv.split(*args, groups=groups)))
    326         dataset_train = torch.utils.data.Subset(dataset, idx_train)
    327         dataset_valid = torch.utils.data.Subset(dataset, idx_valid)

~/Documents/regtech/.venv/lib/python3.7/site-packages/sklearn/model_selection/_split.py in split()
    333                 .format(self.n_splits, n_samples))
    334 
--> 335         for train, test in super().split(X, y, groups):
    336             yield train, test
    337 

~/Documents/regtech/.venv/lib/python3.7/site-packages/sklearn/model_selection/_split.py in split()
     78         X, y, groups = indexable(X, y, groups)
     79         indices = np.arange(_num_samples(X))
---> 80         for test_index in self._iter_test_masks(X, y, groups):
     81             train_index = indices[np.logical_not(test_index)]
     82             test_index = indices[test_index]

~/Documents/regtech/.venv/lib/python3.7/site-packages/sklearn/model_selection/_split.py in _iter_test_masks()
    690 
    691     def _iter_test_masks(self, X, y=None, groups=None):
--> 692         test_folds = self._make_test_folds(X, y)
    693         for i in range(self.n_splits):
    694             yield test_folds == i

~/Documents/regtech/.venv/lib/python3.7/site-packages/sklearn/model_selection/_split.py in _make_test_folds()
    661             raise ValueError("n_splits=%d cannot be greater than the"
    662                              " number of members in each class."
--> 663                              % (self.n_splits))
    664         if self.n_splits > min_groups:
    665             warnings.warn(("The least populated class in y has only %d"

ValueError: n_splits=5 cannot be greater than the number of members in each class.

@stsievert
Member

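Thanks for the traceback. This error comes from inside your model, not from Hyperband: the Skorch model is doing some of its own cross-validation. By default (at least for NeuralNetClassifier), skorch holds out a validation set using a stratified 5-fold split, hence n_splits=5. It does this on whatever data each partial_fit call receives. Passing train_split=None should resolve this: https://skorch.readthedocs.io/en/stable/user/neuralnet.html?highlight=train_split#train-split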

I passed the same parameters in image-denoising/Run.ipynb:

model = TrimParams(  # wrapper around NeuralNetRegressor
    module=Autoencoder,
    criterion=torch.nn.BCELoss,
    warm_start=True,
    train_split=None,
    max_epochs=1,
    callbacks=[]
)
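Why a small chunk trips this: the bottom frames of the traceback show scikit-learn's stratified k-fold refusing to split when every class in y has fewer than n_splits members. Scikit-learn only warns when just some classes are that small. Because skorch runs this split inside each partial_fit call, it sees a single chunk, so the 6-example chunk in the log is the likely culprit. The check below mirrors that condition in plain Python. The helper name and the label lists are made up for illustration.

```python
from collections import Counter


def stratified_split_fails(y, n_splits=5):
    """Hypothetical check: True if a stratified k-fold split of y would raise.

    Mirrors the condition behind "n_splits=5 cannot be greater than the
    number of members in each class": every class has fewer than n_splits examples.
    """
    counts = Counter(y).values()
    return all(count < n_splits for count in counts)


small_chunk = [0, 1, 0, 0, 1, 0]      # 6 examples, made-up binary labels
large_chunk = [0] * 200 + [1] * 91    # 291 examples, made-up class balance
print(stratified_split_fails(small_chunk))  # True
print(stratified_split_fails(large_chunk))  # False
```

With train_split=None skorch skips this internal split entirely, which is why the fix works regardless of chunk size.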

@ToddMorrill
Author

Fantastic. That did it! I'm up and running now. Thank you.

@ToddMorrill
Author

@stsievert I have some questions about rules of thumb for Hyperband.

Let's say my development dataset (train + validation) has 3000 data points, that I would typically train on 80% of it and validate on the other 20%, and that I'd make 5 passes through the training data. In other words, my model converges after it sees 5 * 3000 * 0.8 = 12000 data points.

My grid search parameter dictionary yields 24 unique combinations.

Let's say 250 data points is a good chunk size. Then the number of partial_fit calls required for my model to converge well would be 12000 / 250 = 48. 48 is the value I'm using for max_iter. Do you agree with this approach or would you do something different?
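The arithmetic above, written out as a small hypothetical helper so the inputs are easy to vary. It divides the number of examples needed for convergence by the number of examples each partial_fit call sees (one chunk), rounding up. Integer arithmetic keeps floating-point noise out of the count.

```python
def partial_fit_calls_needed(n_samples, train_frac, n_passes, chunk_size):
    """Hypothetical helper: partial_fit calls for one model to make n_passes over the training data."""
    n_train = round(n_samples * train_frac)      # training rows, e.g. 80% of 3000 = 2400
    examples_to_converge = n_passes * n_train    # e.g. 5 * 2400 = 12000
    return -(-examples_to_converge // chunk_size)  # ceiling division


print(partial_fit_calls_needed(3000, 0.8, 5, 250))  # 48
```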

The results from Hyperband look great! The reason I ask is that a randomized search over my Skorch model takes less time than the Hyperband strategy I described above. That search is restricted to 8 parameter combinations * 3 cross-validation splits, with 5 passes through my training data.

Let's say I want to try only 8 of my 24 parameter combinations in Hyperband. How do I do that while staying mindful of the number of training data points necessary for convergence?
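One thing that bears on this question is that Hyperband decides how many parameter settings to sample from max_iter, not from the size of the grid. The sketch below implements the bracket formulas from the Hyperband paper (Li et al.) with eta = 3, which I'm assuming matches HyperbandSearchCV's default aggressiveness. With max_iter=3 it reproduces the 3 and 2 models in the log above. With max_iter=48 it asks for 49 settings, more than the 24 unique grid combinations. The helper is hypothetical; the fitted search's metadata attribute should be treated as the authoritative source for dask-ml's actual numbers.

```python
def hyperband_brackets(max_iter, eta=3):
    """Hypothetical sketch of Hyperband bracket sizes (Li et al.).

    Returns {bracket s: (n_models, partial_fit calls per model before the first culling)}.
    Integer arithmetic avoids floating-point error in floor(log_eta(max_iter)).
    """
    # s_max = floor(log_eta(max_iter)): the largest s with eta**s <= max_iter
    s_max = 0
    while eta ** (s_max + 1) <= max_iter:
        s_max += 1
    brackets = {}
    for s in range(s_max, -1, -1):
        # n = ceil((s_max + 1) * eta**s / (s + 1)), via integer ceiling division
        numerator = (s_max + 1) * eta ** s
        n_models = -(-numerator // (s + 1))
        initial_calls = max_iter // eta ** s
        brackets[s] = (n_models, initial_calls)
    return brackets


print(hyperband_brackets(3))   # {1: (3, 1), 0: (2, 3)}
print(hyperband_brackets(48))  # {3: (27, 1), 2: (12, 5), 1: (6, 16), 0: (4, 48)}
print(sum(n for n, _ in hyperband_brackets(48).values()))  # 49
```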

@stsievert
Member

Let's say 250 data points is a good chunk size.

When I use Hyperband, I specify the chunk size according to this rule of thumb: https://ml.dask.org/hyper-parameter-search.html#hyperband-parameters-rule-of-thumb.
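As I read the linked docs, the rule of thumb sets max_iter = n_params and the chunk size to n_examples // n_params. Here n_examples is how many examples a model must see to converge, and n_params is roughly how many parameter settings to sample. Applied to the numbers in the previous comment (12000 examples, about 8 settings), that would look like this. The variable names follow the docs, and the output is the rule's suggestion, not a tested setting.

```python
n_examples = 5 * 2400  # 5 passes over 2400 training rows (80% of 3000)
n_params = 8           # roughly how many parameter settings to try

max_iter = n_params                  # partial_fit calls for the longest-trained model
chunk_size = n_examples // n_params  # rows per Dask block, i.e. per partial_fit call

print(max_iter, chunk_size)  # 8 1500
```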

My grid search parameter dictionary yields 24 unique combinations.

24 parameter combinations doesn't sound like much. Are the parameters all discrete, or are some continuous? I ask because it doesn't sound like the search is "compute bound", as described at https://ml.dask.org/hyper-parameter-search.html#scaling-hyperparameter-searches.

I tend to favor using continuous parameters rather than discrete parameters after reading "Random search for hyper-parameter optimization" by Bergstra and Bengio.

## not preferred
# params = {
#     "lr": [1e-3, 1e-2, 1e-1, 1e0],
#     "weight_decay": [1e-5, 1e-4],
#     "alpha": [1e-3, 1e-2, 1e-1],
# }
   
## preferred
from scipy.stats import loguniform
params = {
    "lr": loguniform(1e-3, 1e0),  # preferred
    "weight_decay": loguniform(1e-5, 1e-4),
    "alpha": loguniform(1e-3, 1e-1),
}
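For intuition about what loguniform(1e-3, 1e0) does: it draws the exponent uniformly, so each decade between 1e-3 and 1e0 is about equally likely. The sketch below shows the same idea using only the standard library. sample_loguniform is a hypothetical helper, not a SciPy API. The discrete grid above is counted for comparison.

```python
import itertools
import math
import random


def sample_loguniform(low, high, rng):
    """Draw one value whose log10 is uniform on [log10(low), log10(high))."""
    return 10 ** rng.uniform(math.log10(low), math.log10(high))


rng = random.Random(0)
samples = [sample_loguniform(1e-3, 1e0, rng) for _ in range(10_000)]

# Roughly a third of the draws should land in each decade
for lo, hi in [(1e-3, 1e-2), (1e-2, 1e-1), (1e-1, 1e0)]:
    frac = sum(lo <= s < hi for s in samples) / len(samples)
    print(f"[{lo:g}, {hi:g}): {frac:.2f}")

# The "not preferred" grid only ever tries 4 * 2 * 3 fixed combinations
grid = {"lr": [1e-3, 1e-2, 1e-1, 1e0], "weight_decay": [1e-5, 1e-4], "alpha": [1e-3, 1e-2, 1e-1]}
print(len(list(itertools.product(*grid.values()))))  # 24
```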

@ToddMorrill
Copy link
Author

Good stuff, thank you. I think that dialed it in. Now I'm getting high quality results in much less time.

I'm just running some experiments on my laptop but will scale up to a bigger search space on another machine soon. Currently the space is discrete, but I'll give continuous sampling a try for the parameters where it makes sense (e.g. dropout).

@stsievert
Member

@ToddMorrill could you re-open this issue? Judging by your experience it'd be really useful to include an example of Skorch + Hyperband in dask-examples. I'd certainly appreciate a PR!

ToddMorrill reopened this May 14, 2020
@ToddMorrill
Author

Sure, let me see if I can clean up my example.

@stsievert
Member

Sure, let me see if I can clean up my example.

I'd be happy to clean up your example too. If you could make a PR, I could send in another PR to your branch.

3 participants