
Update workflows #32

Merged
merged 35 commits into main from workflow-docs on May 2, 2022

Conversation

@jgbos (Contributor) commented Apr 27, 2022

See example use here: https://github.com/mit-ll-responsible-ai/responsible-ai-toolbox/blob/workflow-docs/examples/MNIST-Translation-Robustness.ipynb

  • Create base class for workflows
  • Update docs
  • Create tests for workflows

@jgbos jgbos marked this pull request as ready for review April 28, 2022 19:19
@jgbos (Contributor, Author) commented Apr 28, 2022

---------- coverage: platform linux, python 3.8.13-final-0 -----------
Name                                                        Stmts   Miss Branch BrPart  Cover   Missing
-------------------------------------------------------------------------------------------------------
src/rai_toolbox/__init__.py                                     6      0      0      0   100%
src/rai_toolbox/_typing.py                                     28      0      6      0   100%
src/rai_toolbox/_utils/__init__.py                             44      0     22      1    98%   83->86
src/rai_toolbox/_utils/itertools.py                            13      0      6      0   100%
src/rai_toolbox/_utils/stateful.py                             47      0     24      0   100%
src/rai_toolbox/_utils/tqdm.py                                  7      0      0      0   100%
src/rai_toolbox/augmentations/__init__.py                       3      0      0      0   100%
src/rai_toolbox/augmentations/augmix/__init__.py                2      0      0      0   100%
src/rai_toolbox/augmentations/augmix/_implementation.py        33      5     16      3    80%   125-126, 138-139, 146
src/rai_toolbox/augmentations/augmix/transforms.py             73      5     29      5    90%   21, 41->47, 48, 62, 67, 98
src/rai_toolbox/augmentations/fourier/__init__.py               3      0      0      0   100%
src/rai_toolbox/augmentations/fourier/_fourier_basis.py        30      0      8      1    97%   136->139
src/rai_toolbox/augmentations/fourier/_implementations.py      65     47     28      0    22%   68-86, 110-119, 141-209
src/rai_toolbox/augmentations/fourier/transforms.py            94     77     36      0    15%   46-63, 108-204, 220-244, 248
src/rai_toolbox/datasets/__init__.py                            4      0      0      0   100%
src/rai_toolbox/datasets/_cifar10_base.py                      25     11     10      0    46%   43-55, 59-61, 64
src/rai_toolbox/datasets/_imagenet_base.py                     28      4     16      0    91%   71, 81-83
src/rai_toolbox/datasets/_utils.py                             11      6      3      0    36%   15-21
src/rai_toolbox/datasets/cifar10_extensions.py                 42     23     10      0    40%   44-63, 66-75, 78-87, 95
src/rai_toolbox/datasets/cifar_corruptions.py                  63     27     14      0    52%   169-196, 199-208, 214-217, 222
src/rai_toolbox/losses/__init__.py                              2      0      0      0   100%
src/rai_toolbox/losses/_jensen_shannon_divergence.py           12      0      7      0   100%
src/rai_toolbox/losses/_utils.py                               10      0      0      0   100%
src/rai_toolbox/mushin/__init__.py                              4      0      0      0   100%
src/rai_toolbox/mushin/_compatibility.py                       13      0      4      0   100%
src/rai_toolbox/mushin/_utils.py                               56     39     28      0    23%   52-84, 111-143
src/rai_toolbox/mushin/hydra.py                                79      8     32      6    87%   49, 62, 85, 142->145, 149-150, 155, 158, 176
src/rai_toolbox/mushin/lightning/__init__.py                    3      0      0      0   100%
src/rai_toolbox/mushin/lightning/_pl_main.py                   16      0      2      0   100%
src/rai_toolbox/mushin/lightning/callbacks.py                  36      0     10      1    98%   55->60
src/rai_toolbox/mushin/lightning/launchers.py                  73      0     16      1    99%   228->232
src/rai_toolbox/mushin/testing/__init__.py                      0      0      0      0   100%
src/rai_toolbox/mushin/testing/lightning.py                    66      0     12      0   100%
src/rai_toolbox/mushin/workflows.py                           116      0     52      0   100%
src/rai_toolbox/optim/__init__.py                               4      0      0      0   100%
src/rai_toolbox/optim/frank_wolfe.py                           54      0     20      0   100%
src/rai_toolbox/optim/lp_space.py                              96      0     26      0   100%
src/rai_toolbox/optim/optimizer.py                            103      0     50      0   100%
src/rai_toolbox/perturbations/__init__.py                       4      0      0      0   100%
src/rai_toolbox/perturbations/init.py                          23      0      0      0   100%
src/rai_toolbox/perturbations/models.py                        29      0     14      0   100%
src/rai_toolbox/perturbations/solvers.py                       85      0     34      0   100%
-------------------------------------------------------------------------------------------------------
TOTAL                                                        1505    252    535     18    81%

@jgbos jgbos requested a review from rsokl April 28, 2022 20:22
@jgbos (Contributor, Author) commented Apr 29, 2022

Oh, maybe we should at least allow arguments that are ints, floats, etc., and convert them to strings (e.g., so I don't have to do num_nodes="4").
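Something like this minimal coercion would cover it (a hypothetical helper, not current behavior):

def coerce_override(value):
    # stringify bare numbers so users can write num_nodes=4
    # instead of num_nodes="4"; leave everything else untouched
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return str(value)
    return value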

@rsokl (Contributor) commented Apr 30, 2022

@jgbos also, heads-up: I am making significant changes to the xarray capabilities. I don't want to push yet because tests are failing (and we're almost out of minutes on Actions 😢).

@sanjeevm345 commented
Added $ so hopefully we will not run out of minutes

@rsokl (Contributor) commented Apr 30, 2022

> Added $ so hopefully we will not run out of minutes

Awesome. Running out of minutes would have been a real show-stopper for our progress.

@rsokl (Contributor) left a comment

Changes that I made

  • Added rai_toolbox.mushin.multirun and rai_toolbox.mushin.hydra_list. These are both subclasses of list, which can be passed to run -- these are currently the only non-string sequences permitted by run. The former indicates that its contents should be iterated over in a multirun; the latter indicates that the list should be passed as-is, as a single config value (see the sketch after this list).
  • Permit dict inputs to run
  • Loading metrics converted all numbers to floats; now ints are preserved. We should see if there is functionality we can use from Hydra to ensure that these values are loaded faithfully.
  • Added BaseWorkflow.multirun_task_overrides, which is a self-populating attribute (i.e. it loads itself upon being accessed). This loads a dict of overrides, e.g. {'epsilon': [1.0, 2.0, 3.0], "foo": "a"}, whereas self.workflow_overrides stores a dict of repeated values such that all entries are of equal length: {'epsilon': [1.0, 2.0, 3.0], "foo": ["a", "a", "a"]}.
  • Most of the RobustnessCurve logic is now generalized to MultiRunMetricsWorkflow, which RobustnessCurve inherits from.
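As a sketch of how these two list types read at the call site (the layers override is hypothetical; wf is any workflow instance):

from rai_toolbox.mushin import multirun, hydra_list

wf.run(
    epsilon=multirun([1.0, 2.0, 3.0]),  # iterated over: one job per epsilon
    layers=hydra_list([10, 20, 30]),    # passed as-is: one list-valued config
)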

Improved xarray functionality

Using dimensions with coordinates

Prior to bcce75b, to_xarray leveraged "non-dimension coordinates" (the default dimension name being "x") to associate multi-run coordinates and data variables. Consider the following:

from rai_toolbox.mushin.workflows import RobustnessCurve
import torch as tr

class LocalRobustness(RobustnessCurve):
    @staticmethod
    def evaluation_task(epsilon):
        val = 100 - epsilon**2

        result = dict(accuracies=val+2)

        tr.save(result, "test_metrics.pt")
        return result

>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0])
[2022-05-01 12:01:00,451][HYDRA] Launching 3 jobs locally
[2022-05-01 12:01:00,451][HYDRA] 	#0 : +epsilon=1.0
[2022-05-01 12:01:00,543][HYDRA] 	#1 : +epsilon=3.0
[2022-05-01 12:01:00,640][HYDRA] 	#2 : +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds
<xarray.Dataset>
Dimensions:     (x: 3)
Coordinates:
    epsilon     (x) float64 1.0 2.0 3.0
Dimensions without coordinates: x
Data variables:
    accuracies  (x) float64 101.0 98.0 93.0

Because of this x dimension, we cannot directly query an accuracy based on its epsilon value:

>>> ds.sel(epsilon=1.0)
KeyError: 'no index found for coordinate epsilon'
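Under the old layout, the selection could only be made by first promoting epsilon to a dimension coordinate by hand, using standard xarray machinery:

>>> ds.swap_dims({"x": "epsilon"}).sel(epsilon=1.0)  # workaround, no longer needed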

As of bcce75b, to_xarray no longer introduces this spurious dimension; instead, the dimensions of the dataset are explicitly determined by the multi-run quantities:

>>> ds
<xarray.Dataset>
Dimensions:     (epsilon: 3)
Coordinates:
  * epsilon     (epsilon) float64 1.0 2.0 3.0
Data variables:
    accuracies  (epsilon) float64 101.0 98.0 93.0

And thus we can query metrics based on these coordinates:

>>> ds.sel(epsilon=[1.0, 3.0])
<xarray.Dataset>
Dimensions:     (epsilon: 2)
Coordinates:
  * epsilon     (epsilon) float64 1.0 3.0
Data variables:
    accuracies  (epsilon) float64 101.0 93.0

Using non-coordinate dimensions should never be needed by a multi-run. Because multi-runs are effectively nested for-loops, we can always count on having a "grid" of coordinate values and corresponding results that can be stored as a multi-dimensional array. E.g. if we multi-run over two checkpoints and three epsilons, each producing a scalar accuracy, then we should be able to produce a shape-(2, 3) DataArray storing these accuracies with the following structure:

Dims: ckpt, eps

       eps_1   eps_2    eps_3
ckpt_a    *       *        *
ckpt_b    *       *        *
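A minimal sketch of that grid built directly with xarray (placeholder values; names match the table above):

import numpy as np
import xarray as xr

# one scalar accuracy per (ckpt, eps) pair
accuracies = xr.DataArray(
    np.zeros((2, 3)),
    dims=("ckpt", "eps"),
    coords={"ckpt": ["ckpt_a", "ckpt_b"], "eps": [1.0, 2.0, 3.0]},
)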

This even works if we have non-scalar metrics being saved....

Supporting non-scalar metrics

to_xarray can now handle non-scalar metrics; we make up coordinate names for the unknown dimensions. Consider the following example, where the images metric is a shape-(4, 3) array:

class LocalRobustness(RobustnessCurve):
    @staticmethod
    def evaluation_task(epsilon):
        val = 100 - epsilon**2

        result = dict(images=[[val]*3]*4, accuracies=val+2)

        tr.save(result, "test_metrics.pt")
        return result

>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0], ckpt="a,b")
[2022-05-01 12:10:16,277][HYDRA] Launching 6 jobs locally
[2022-05-01 12:10:16,277][HYDRA] 	#0 : +ckpt=a +epsilon=1.0
[2022-05-01 12:10:16,368][HYDRA] 	#1 : +ckpt=a +epsilon=3.0
[2022-05-01 12:10:16,459][HYDRA] 	#2 : +ckpt=a +epsilon=2.0
[2022-05-01 12:10:16,551][HYDRA] 	#3 : +ckpt=b +epsilon=1.0
[2022-05-01 12:10:16,641][HYDRA] 	#4 : +ckpt=b +epsilon=3.0
[2022-05-01 12:10:16,732][HYDRA] 	#5 : +ckpt=b +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds
<xarray.Dataset>
Dimensions:      (ckpt: 2, epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * ckpt         (ckpt) <U1 'a' 'b'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (ckpt, epsilon, images_dim0, images_dim1) float64 99.0 ... 91.0
    accuracies   (ckpt, epsilon) float64 101.0 98.0 93.0 101.0 98.0 93.0

>>> ds.accuracies
<xarray.DataArray 'accuracies' (ckpt: 2, epsilon: 3)>
array([[101.,  98.,  93.],
       [101.,  98.,  93.]])
Coordinates:
  * ckpt     (ckpt) <U1 'a' 'b'
  * epsilon  (epsilon) float64 1.0 2.0 3.0

>>> ds.images
<xarray.DataArray 'images' (ckpt: 2, epsilon: 3, images_dim0: 4, images_dim1: 3)>
array([[[[99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.]],

        [[96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.]],

        [[91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.]]],


       [[[99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.],
         [99., 99., 99.]],

        [[96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.],
         [96., 96., 96.]],

        [[91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.],
         [91., 91., 91.]]]])
Coordinates:
  * ckpt         (ckpt) <U1 'a' 'b'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2

This is all really powerful; see how we can get the results for ckpt=a:

>>> ds.sel(ckpt="a")
<xarray.Dataset>
Dimensions:      (epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
    ckpt         <U1 'a'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (epsilon, images_dim0, images_dim1) float64 99.0 99.0 ... 91.0
    accuracies   (epsilon) float64 101.0 98.0 93.0

Some design notes on our xarrays

This is how I think of our xarray functionality; this should live somewhere in our docs eventually.

A single metric, generated across a multi-run, is stored as a DataArray

Assuming the workflow corresponds to a multi-run, each metric saved across the multi-run (e.g. "accuracy" across epsilons 1, 2, 3) is stored as an xarray.DataArray, whose dimensions are:

  • the names of the multi-run parameters (e.g. epsilon and ckpt)
  • the dimensions needed to describe a multi-dimensional metric. For a scalar like accuracy no dimensions are added. But if each multi-run iteration stores a shape-(H, W) perturbed image, for example, then two dimensions (of lengths H and W) will be added to the metric's data array.

The coordinates are thus:

  • the values that were iterated over in the multi-run
  • the integers needed to index any additional dimensions for non-scalar metrics (e.g. 0, 1, ..., H - 1 and 0, 1, ..., W - 1)

The Dataset returned by .to_xarray() is the combination of all of the DataArrays that we created; i.e. it stores all of the metrics saved by the workflow. Even if we only have one metric (e.g. accuracy), we still store it as a Dataset even though we could just return the DataArray.

Non-multirun parameters are stored in the .attrs dictionary by default.
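Putting these notes together, here is a hand-built sketch of the kind of Dataset that .to_xarray() assembles (placeholder values; shapes match the example above):

import numpy as np
import xarray as xr

epsilon = [1.0, 2.0, 3.0]

# one DataArray per metric; dims = multi-run params + extra metric dims
accuracies = xr.DataArray(
    np.zeros(3), dims="epsilon", coords={"epsilon": epsilon}
)
images = xr.DataArray(
    np.zeros((3, 4, 3)),
    dims=("epsilon", "images_dim0", "images_dim1"),
    coords={"epsilon": epsilon},
)

# the Dataset combines every metric; non-multirun params land in attrs
ds = xr.Dataset({"accuracies": accuracies, "images": images}, attrs={"ckpt": "a"})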

Support for singleton dims

Suppose that we only configure a single checkpoint for one workflow. By default, ckpt will not be stored as a dimension:

>>> wf = LocalRobustness()
>>> wf.run(epsilon=[1.0, 3.0, 2.0], ckpt="a")  
[2022-05-01 12:16:17,960][HYDRA] Launching 3 jobs locally
[2022-05-01 12:16:17,960][HYDRA] 	#0 : +ckpt=a +epsilon=1.0
[2022-05-01 12:16:18,054][HYDRA] 	#1 : +ckpt=a +epsilon=3.0
[2022-05-01 12:16:18,145][HYDRA] 	#2 : +ckpt=a +epsilon=2.0

>>> ds = wf.to_xarray()
>>> ds   # note Attributes contains ckpt
<xarray.Dataset>
Dimensions:      (epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (epsilon, images_dim0, images_dim1) float64 99.0 99.0 ... 91.0
    accuracies   (epsilon) float64 101.0 98.0 93.0
Attributes:
    ckpt:     a

But we can modify this behavior and make ckpt behave like a dimension; that way we can easily concat/merge with other workflows that vary their checkpoints.

>>> wf.to_xarray(non_multirun_params_as_singleton_dims=True)  # note: ckpt is now a dimension
<xarray.Dataset>
Dimensions:      (ckpt: 1, epsilon: 3, images_dim0: 4, images_dim1: 3)
Coordinates:
  * ckpt         (ckpt) <U1 'a'
  * epsilon      (epsilon) float64 1.0 2.0 3.0
  * images_dim0  (images_dim0) int32 0 1 2 3
  * images_dim1  (images_dim1) int32 0 1 2
Data variables:
    images       (ckpt, epsilon, images_dim0, images_dim1) float64 99.0 ... 91.0
    accuracies   (ckpt, epsilon) float64 101.0 98.0 93.0
Attributes:
    ckpt:     a
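For example, assuming a second workflow wf_b was run with ckpt="b" (wf_b is hypothetical here), concatenating along the singleton dimension is then a one-liner:

>>> import xarray as xr
>>> ds_a = wf.to_xarray(non_multirun_params_as_singleton_dims=True)
>>> ds_b = wf_b.to_xarray(non_multirun_params_as_singleton_dims=True)
>>> xr.concat([ds_a, ds_b], dim="ckpt")  # Dimensions now include (ckpt: 2)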

@rsokl (Contributor) commented May 1, 2022

Moving Forward

I am happy to have this merged as-is, but moving forward I think we can make more improvements along these lines:

More self-loading/checking attributes

I added a self-populating attribute, BaseWorkflow.multirun_task_overrides. This design is nice because, in other parts of the class, you can simply use this attribute without having to worry about whether or not it is loaded; it takes care of itself, and its behavior can be modified surgically via inheritance.

In general, we should try to design more of the workflow's attributes in this way. It will make the workflow's statefulness much simpler to think about, and the behavior of a post-job workflow vs. a loaded workflow more consistent.
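A minimal sketch of the self-populating pattern (the loader name here is illustrative, not the toolbox's actual implementation):

class BaseWorkflow:
    @property
    def multirun_task_overrides(self):
        # self-populating: load on first access, then cache, so consumers
        # never need to check whether the attribute has been loaded
        if getattr(self, "_multirun_task_overrides", None) is None:
            self._multirun_task_overrides = self._load_overrides()
        return self._multirun_task_overrides

    def _load_overrides(self):
        # hypothetical loader, e.g. parse the overrides recorded
        # in the workflow's working directory
        ...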

Less-permissive defaults

Right now it is hard to know whether a workflow is in a "good"/populated state, especially if you used load_from_dir. For example, having working_dir default to cwd doesn't make sense; it should be set via run, load_from_dir, or explicitly by the user. My to_xarray was loading an empty array because of this default, and it was hard to debug.

Workflows should write to working-dir to make re-loading simpler

Currently RobustnessCurve has some cryptic behavior; e.g. load_from_dir has metrics_filename: str = "test_metrics.pt", which is a relatively unintuitive default. It would be nice if run() wrote a .workflow/workflow.yaml file that stored the paths of the metrics that were written, so that the workflow can be loaded without the user needing to remember what metric name was used.

There are probably lots of things that could be saved in .workflow/workflow.yaml along these lines; metrics_filename is just an example.
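As a sketch, run() could record something like this (file layout and keys are hypothetical):

from pathlib import Path
import yaml

def _record_workflow_state(working_dir: Path, metrics_filename: str) -> None:
    meta_dir = working_dir / ".workflow"
    meta_dir.mkdir(exist_ok=True)
    # persist what load_from_dir would otherwise need to be told explicitly
    (meta_dir / "workflow.yaml").write_text(
        yaml.safe_dump({"metrics_filename": metrics_filename})
    )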

Standardizing how metrics are stored/saved

RobustnessCurve currently assumes that your task function returns a dict of metric-name -> metric-value. This makes sense to me, but I don't think that it is documented anywhere. Is this something that we can/should standardize? Is it only appropriate for RobustnessCurve?

Generalizing RobustnessCurve's functionality

EDIT: I decided to go ahead and create MultiRunMetricsWorkflow, which RobustnessCurve inherits from.

Our to_xarray functionality is powerful and seems like it could be reused across many workflows. Could we move the functionality to a mixin? @jgbos, would you be able to use the current to_xarray functionality nearly as-is for other workflows that you use? It seems like this also hinges on the aforementioned "Standardizing how metrics are stored/saved".

In fact, I wonder if nearly all of RobustnessCurve could be generalized to a MultiRunMetricsWorkflow, which could be a general template for multi-run workflows that saves/loads metrics and makes them accessible via xarray. Indeed, the only robustness-curve-specific thing in the workflow is that epsilon gets special treatment in run.

Nice reprs

It would be awesome to be able to understand the state of a workflow at a glance, via its repr!
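For instance, a hypothetical repr along these lines:

def __repr__(self):
    # surface the workflow's state at a glance
    return (
        f"{type(self).__name__}("
        f"working_dir={getattr(self, 'working_dir', None)!r}, "
        f"overrides={list(self.multirun_task_overrides)!r})"
    )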

@rsokl (Contributor) commented May 1, 2022

To-Do:

  • Move MultiRunMetricsWorkflow to mushin namespace
  • Add sphinx docs for MultiRunMetricsWorkflow

@rsokl rsokl merged commit b028477 into main May 2, 2022
@rsokl rsokl deleted the workflow-docs branch May 2, 2022 03:40