Update workflows #32
Conversation
```
---------- coverage: platform linux, python 3.8.13-final-0 -----------
Name                                                     Stmts   Miss Branch BrPart  Cover   Missing
-------------------------------------------------------------------------------------------------------
src/rai_toolbox/__init__.py                                  6      0      0      0   100%
src/rai_toolbox/_typing.py                                  28      0      6      0   100%
src/rai_toolbox/_utils/__init__.py                          44      0     22      1    98%   83->86
src/rai_toolbox/_utils/itertools.py                         13      0      6      0   100%
src/rai_toolbox/_utils/stateful.py                          47      0     24      0   100%
src/rai_toolbox/_utils/tqdm.py                               7      0      0      0   100%
src/rai_toolbox/augmentations/__init__.py                    3      0      0      0   100%
src/rai_toolbox/augmentations/augmix/__init__.py             2      0      0      0   100%
src/rai_toolbox/augmentations/augmix/_implementation.py     33      5     16      3    80%   125-126, 138-139, 146
src/rai_toolbox/augmentations/augmix/transforms.py          73      5     29      5    90%   21, 41->47, 48, 62, 67, 98
src/rai_toolbox/augmentations/fourier/__init__.py            3      0      0      0   100%
src/rai_toolbox/augmentations/fourier/_fourier_basis.py     30      0      8      1    97%   136->139
src/rai_toolbox/augmentations/fourier/_implementations.py   65     47     28      0    22%   68-86, 110-119, 141-209
src/rai_toolbox/augmentations/fourier/transforms.py         94     77     36      0    15%   46-63, 108-204, 220-244, 248
src/rai_toolbox/datasets/__init__.py                         4      0      0      0   100%
src/rai_toolbox/datasets/_cifar10_base.py                   25     11     10      0    46%   43-55, 59-61, 64
src/rai_toolbox/datasets/_imagenet_base.py                  28      4     16      0    91%   71, 81-83
src/rai_toolbox/datasets/_utils.py                          11      6      3      0    36%   15-21
src/rai_toolbox/datasets/cifar10_extensions.py              42     23     10      0    40%   44-63, 66-75, 78-87, 95
src/rai_toolbox/datasets/cifar_corruptions.py               63     27     14      0    52%   169-196, 199-208, 214-217, 222
src/rai_toolbox/losses/__init__.py                           2      0      0      0   100%
src/rai_toolbox/losses/_jensen_shannon_divergence.py        12      0      7      0   100%
src/rai_toolbox/losses/_utils.py                            10      0      0      0   100%
src/rai_toolbox/mushin/__init__.py                           4      0      0      0   100%
src/rai_toolbox/mushin/_compatibility.py                    13      0      4      0   100%
src/rai_toolbox/mushin/_utils.py                            56     39     28      0    23%   52-84, 111-143
src/rai_toolbox/mushin/hydra.py                             79      8     32      6    87%   49, 62, 85, 142->145, 149-150, 155, 158, 176
src/rai_toolbox/mushin/lightning/__init__.py                 3      0      0      0   100%
src/rai_toolbox/mushin/lightning/_pl_main.py                16      0      2      0   100%
src/rai_toolbox/mushin/lightning/callbacks.py               36      0     10      1    98%   55->60
src/rai_toolbox/mushin/lightning/launchers.py               73      0     16      1    99%   228->232
src/rai_toolbox/mushin/testing/__init__.py                   0      0      0      0   100%
src/rai_toolbox/mushin/testing/lightning.py                 66      0     12      0   100%
src/rai_toolbox/mushin/workflows.py                        116      0     52      0   100%
src/rai_toolbox/optim/__init__.py                            4      0      0      0   100%
src/rai_toolbox/optim/frank_wolfe.py                        54      0     20      0   100%
src/rai_toolbox/optim/lp_space.py                           96      0     26      0   100%
src/rai_toolbox/optim/optimizer.py                         103      0     50      0   100%
src/rai_toolbox/perturbations/__init__.py                    4      0      0      0   100%
src/rai_toolbox/perturbations/init.py                       23      0      0      0   100%
src/rai_toolbox/perturbations/models.py                     29      0     14      0   100%
src/rai_toolbox/perturbations/solvers.py                    85      0     34      0   100%
-------------------------------------------------------------------------------------------------------
TOTAL                                                     1505    252    535     18    81%
```
Oh, maybe we should at least allow arguments that are ints, floats, etc and convert them to string (e.g., so I don't have to do
… log non-multirun configs as atts
@jgbos also heads-up I am making significant changes to the xarray capabilities. Don't want to push yet because tests are failing (and we're almost out of minutes on Actions 😢 )
Added $ so hopefully we will not run out of minutes
Awesome. Running out of minutes would have been a real show-stopper for our progress. |
…ay for loaded workflow
Changes that I made

- Added `rai_toolbox.mushin.multirun` and `rai_toolbox.mushin.hydra_list`. These are both subclasses of `list`, which can be passed to `run` -- these are currently the only non-string sequences permitted by `run`. The former indicates that the contents should be iterated over in a multirun. The latter indicates that the list should be passed as-is, as a config value.
- Permit `dict` inputs to `run`.
- Loading metrics converted all numbers to floats; now ints are preserved. We should see if there is functionality we can use from Hydra to ensure that these values are loaded faithfully.
- Added `BaseWorkflow.multi_task_overrides`, which is a self-populating attribute (i.e. it loads itself upon being accessed). This loads a dict of overrides, e.g. `{'epsilon': [1.0, 2.0, 3.0], "foo": "a"}`, whereas `self.workflow_overrides` stores a dict of repeated values such that all entries are of equal length: `{'epsilon': [1.0, 2.0, 3.0], "foo": ["a", "a", "a"]}`.
- Most of the `RobustnessCurve` logic is now generalized to `MultiRunMetricsWorkflow`, which `RobustnessCurve` inherits from.
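To illustrate the distinction between the two list markers, here is a minimal sketch; the real implementations may differ, and `format_override` is a hypothetical helper for illustration, not toolbox API:

```python
# Thin list subclasses used purely as type markers (sketch of the idea).
class multirun(list):
    """Contents are swept over in a multirun, one job per entry."""

class hydra_list(list):
    """The whole list is passed through as a single config value."""

def format_override(name, value):
    # Hypothetical: how a launcher might turn the markers into CLI overrides.
    if isinstance(value, multirun):
        # comma-separated sweep syntax: one job per entry
        return f"{name}={','.join(str(v) for v in value)}"
    # the list itself is the config value
    return f"{name}={list(value)}"

format_override("epsilon", multirun([1.0, 2.0]))  # -> "epsilon=1.0,2.0"
format_override("layers", hydra_list([10, 20]))   # -> "layers=[10, 20]"
```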
Improved xarray functionality
Using dimensions with coordinates
Prior to bcce75b, `to_xarray` leveraged "non-dimension coordinates" (the default dimension name being `"x"`) to associate multi-run coordinates and data variables. Consider the following:
```python
from rai_toolbox.mushin.workflows import RobustnessCurve
import torch as tr

class LocalRobustness(RobustnessCurve):
    @staticmethod
    def evaluation_task(epsilon):
        val = 100 - epsilon**2
        result = dict(accuracies=val + 2)
        tr.save(result, "test_metrics.pt")
        return result
```
```python
>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0])
[2022-05-01 12:01:00,451][HYDRA] Launching 3 jobs locally
[2022-05-01 12:01:00,451][HYDRA] #0 : +epsilon=1.0
[2022-05-01 12:01:00,543][HYDRA] #1 : +epsilon=3.0
[2022-05-01 12:01:00,640][HYDRA] #2 : +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds
<xarray.Dataset>
Dimensions:     (x: 3)
Coordinates:
    epsilon     (x) float64 1.0 2.0 3.0
Dimensions without coordinates: x
Data variables:
    accuracies  (x) float64 101.0 98.0 93.0
```
Because of this `x` dimension, we cannot directly query an accuracy based on its epsilon value:
```python
>>> ds.sel(epsilon=1.0)
KeyError: 'no index found for coordinate epsilon'
```
As of bcce75b, `to_xarray` no longer introduces this spurious dimension; instead, the dimensions of the dataset are explicitly determined by the multi-run quantities:
```python
>>> ds
<xarray.Dataset>
Dimensions:     (epsilon: 3)
Coordinates:
  * epsilon     (epsilon) float64 1.0 2.0 3.0
Data variables:
    accuracies  (epsilon) float64 101.0 98.0 93.0
```
And thus we can query metrics based on these coordinates:
```python
>>> ds.sel(epsilon=[1.0, 3.0])
<xarray.Dataset>
Dimensions:     (epsilon: 2)
Coordinates:
  * epsilon     (epsilon) float64 1.0 3.0
Data variables:
    accuracies  (epsilon) float64 101.0 93.0
```
Using non-coordinate dimensions should never be needed by a multi-run. Because multi-runs are effectively nested for-loops, we can always count on having a "grid" of coordinate values and corresponding results that can be stored as a multi-dimensional array. E.g. if we multi-run over two checkpoints and three epsilons, each producing a scalar accuracy, then we should be able to produce a shape-(2, 3) DataArray storing these accuracies with the following structure:
```
Dims: ckpt, eps

         eps_1  eps_2  eps_3
ckpt_a     *      *      *
ckpt_b     *      *      *
```
This even works if we have non-scalar metrics being saved....
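As a pure-Python sketch of that reshaping (the accuracy values below are hypothetical; a real workflow would read them from the job directories):

```python
# Six jobs from a 2-checkpoint x 3-epsilon multirun, flattened in the order
# the jobs were launched (ckpt varies slowest, epsilon fastest).
ckpts = ["ckpt_a", "ckpt_b"]
epsilons = [1.0, 2.0, 3.0]
flat_accuracies = [99.0, 96.0, 91.0, 99.0, 96.0, 91.0]  # hypothetical values

# Reshape the flat results into the shape-(2, 3) grid described above.
n_eps = len(epsilons)
grid = [flat_accuracies[i * n_eps:(i + 1) * n_eps] for i in range(len(ckpts))]
# grid[0] holds ckpt_a's accuracies across the three epsilons.
```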
Supporting non-scalar metrics
`to_xarray` can now handle non-scalar metrics; we make up coordinate names for the unknown dimensions. Consider the following example, where the `images` metric is a shape-(4, 3) array:
```python
class LocalRobustness(RobustnessCurve):
    @staticmethod
    def evaluation_task(epsilon):
        val = 100 - epsilon**2
        result = dict(images=[[val] * 3] * 4, accuracies=val + 2)
        tr.save(result, "test_metrics.pt")
        return result
```
```python
>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0], ckpt="a,b")
[2022-05-01 12:10:16,277][HYDRA] Launching 6 jobs locally
[2022-05-01 12:10:16,277][HYDRA] #0 : +ckpt=a +epsilon=1.0
[2022-05-01 12:10:16,368][HYDRA] #1 : +ckpt=a +epsilon=3.0
[2022-05-01 12:10:16,459][HYDRA] #2 : +ckpt=a +epsilon=2.0
[2022-05-01 12:10:16,551][HYDRA] #3 : +ckpt=b +epsilon=1.0
[2022-05-01 12:10:16,641][HYDRA] #4 : +ckpt=b +epsilon=3.0
[2022-05-01 12:10:16,732][HYDRA] #5 : +ckpt=b +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds
<xarray.Dataset>
Dimensions:      (ckpt: 2, epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * ckpt         (ckpt) <U1 'a' 'b'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (ckpt, epsilon, images_dim0, images_dim1) float64 99.0 ... 91.0
    accuracies   (ckpt, epsilon) float64 101.0 98.0 93.0 101.0 98.0 93.0

>>> ds.accuracies
<xarray.DataArray 'accuracies' (ckpt: 2, epsilon: 3)>
array([[101.,  98.,  93.],
       [101.,  98.,  93.]])
Coordinates:
  * ckpt     (ckpt) <U1 'a' 'b'
  * epsilon  (epsilon) float64 1.0 2.0 3.0

>>> ds.images
<xarray.DataArray 'images' (ckpt: 2, epsilon: 3, images_dim0: 4, images_dim1: 3)>
array([[[[99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.]],

        [[96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.]],

        [[91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.]]],


       [[[99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.]],

        [[96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.]],

        [[91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.]]]])
Coordinates:
  * ckpt         (ckpt) <U1 'a' 'b'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
```
This is all really powerful; see how we can get the results for `ckpt="a"`:
```python
>>> ds.sel(ckpt="a")
<xarray.Dataset>
Dimensions:      (epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
    ckpt         <U1 'a'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (epsilon, images_dim0, images_dim1) float64 99.0 99.0 ... 91.0
    accuracies   (epsilon) float64 101.0 98.0 93.0
```
Some design notes on our xarrays
This is how I think of our xarray functionality; this should live somewhere in our docs eventually.
A single metric, generated across a multi-run, is stored as a DataArray
Assuming the workflow corresponds to a multi-run, each metric saved across the multi-run (e.g. "accuracy" across epsilons 1, 2, 3) is stored as an `xarray.DataArray`, whose dimensions are:
- the names of the multi-run parameters (e.g. `epsilon` and `ckpt`)
- the dimensions needed to describe a multi-dimensional metric. For a scalar like `accuracy`, no dimensions are added. But if each multi-run iteration stores a shape-(H, W) perturbed image, for example, then two dimensions (of lengths H and W) will be added to the metric's data array.
The coordinates are thus:
- the values that were iterated over in the multi-run
- the integers needed to index any additional dimensions for non-scalar metrics (e.g. `0, 1, ..., H - 1` and `0, 1, ..., W - 1`)
The `Dataset` returned by `.to_xarray()` is the combination of all of the `DataArray`s that we created, i.e. it stores all of the metrics saved by the workflow. Even if we only have one metric (e.g. `accuracy`), we still store it as a `Dataset`, even though we could just return the `DataArray`. Non-multirun parameters are stored in the `.attrs` dictionary by default.
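A minimal sketch of that assembly, assuming `xarray` is available (the values mirror the earlier `accuracies` example; this is an illustration of the structure, not the toolbox's actual construction code):

```python
import xarray as xr

# One DataArray per saved metric; its dimension is named after the
# multirun parameter that was swept.
accuracies = xr.DataArray(
    [101.0, 98.0, 93.0],
    dims="epsilon",
    coords={"epsilon": [1.0, 2.0, 3.0]},
    name="accuracies",
)

# Even a lone metric is wrapped in a Dataset, and a non-multirun
# parameter lands in the .attrs dictionary.
ds = accuracies.to_dataset()
ds.attrs["ckpt"] = "a"
```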
Support for singleton dims
Suppose that we only configure a single checkpoint for one workflow. By default, `ckpt` will not be stored as a dimension:
```python
>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0], ckpt="a")
[2022-05-01 12:16:17,960][HYDRA] Launching 3 jobs locally
[2022-05-01 12:16:17,960][HYDRA] #0 : +ckpt=a +epsilon=1.0
[2022-05-01 12:16:18,054][HYDRA] #1 : +ckpt=a +epsilon=3.0
[2022-05-01 12:16:18,145][HYDRA] #2 : +ckpt=a +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds  # note: Attributes contains ckpt
<xarray.Dataset>
Dimensions:      (epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (epsilon, images_dim0, images_dim1) float64 99.0 99.0 ... 91.0
    accuracies   (epsilon) float64 101.0 98.0 93.0
Attributes:
    ckpt:     a
```
But we can modify this behavior and make `ckpt` behave like a dimension; that way, we can easily concat/merge with other workflows that vary their checkpoints.
```python
>>> wf.to_xarray(non_multirun_params_as_singleton_dims=True)  # note: ckpt is now a dimension
<xarray.Dataset>
Dimensions:      (ckpt: 1, epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * ckpt         (ckpt) <U1 'a'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (ckpt, epsilon, images_dim0, images_dim1) float64 99.0 ... 91.0
    accuracies   (ckpt, epsilon) float64 101.0 98.0 93.0
Attributes:
    ckpt:     a
```
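To sketch why the singleton dimension makes merging easy (assuming `xarray`; `ds_a` and `ds_b` are hypothetical stand-ins for the datasets of two single-checkpoint workflows, with made-up accuracy values):

```python
import xarray as xr

# Two hypothetical single-checkpoint workflows, each exported with
# non_multirun_params_as_singleton_dims=True, so ckpt is a length-1 dimension.
ds_a = xr.Dataset(
    {"accuracies": (("ckpt", "epsilon"), [[101.0, 98.0, 93.0]])},
    coords={"ckpt": ["a"], "epsilon": [1.0, 2.0, 3.0]},
)
ds_b = xr.Dataset(
    {"accuracies": (("ckpt", "epsilon"), [[99.0, 97.0, 92.0]])},
    coords={"ckpt": ["b"], "epsilon": [1.0, 2.0, 3.0]},
)

# Because ckpt is a real dimension, the results concatenate directly into
# one dataset spanning both checkpoints.
combined = xr.concat([ds_a, ds_b], dim="ckpt")
```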
Moving Forward

I am happy to have this be merged as-is, but moving forward I think we can make more improvements along these lines:

More self-loading/checking attributes

I added a self-populating attribute, `BaseWorkflow.multirun_task_overrides`. This design is nice because, in other parts of the class, you can simply use this attribute without having to worry about whether or not it is loaded – it takes care of itself, and its behavior can be modified surgically via inheritance. In general, we should try to design more of the workflow's attributes in this way. It will make the workflow's statefulness much simpler to think about, and the behavior of a post-job workflow vs. a loaded workflow more consistent.

Less-permissive defaults

Right now it is hard to know whether a workflow is in a "good"/populated state, especially if you used

Workflows should write to working-dir to make re-loading simpler

Currently

There are probably lots of things that could be saved in

Standardizing how metrics are stored/saved
Generalizing
To-Do:
See example use here: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/workflow-docs/examples/MNIST-Translation-Robustness.ipynb