Added jaxline pipeline to train adversarially robust models.
PiperOrigin-RevId: 383399487
Sven Gowal authored and derpson committed Jul 14, 2021
1 parent d8df415 commit 5909da5
Showing 36 changed files with 5,229 additions and 79 deletions.
78 changes: 70 additions & 8 deletions adversarial_robustness/README.md
We have released our top-performing models in two formats compatible with
[JAX](https://github.com/google/jax) and [PyTorch](https://pytorch.org/).
This repository also contains our model definitions.

## Running the code

### Downloading a model

The following table contains the models from **Rebuffi et al., 2021**.
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-70-16 | &#x2717; | 63.56% | 34.64% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn70-16_cutmix_ddpm.pt)
| CIFAR-100 | &#8467;<sub>&infin;</sub> | 8 / 255 | WRN-28-10 | &#x2717; | 62.41% | 32.06% | [jax](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.npy), [pt](https://storage.googleapis.com/dm-adversarial-robustness/cifar100_linf_wrn28-10_cutmix_ddpm.pt)
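
If you prefer to script the download, the sketch below (not taken from the
README) fetches one of the checkpoints listed above using only the Python
standard library; the local filename is an arbitrary choice:

```
# Download a checkpoint; the URL is copied verbatim from the table above and
# the local filename is arbitrary.
import urllib.request

url = ('https://storage.googleapis.com/dm-adversarial-robustness/'
       'cifar100_linf_wrn70-16_cutmix_ddpm.npy')
urllib.request.urlretrieve(url, 'cifar100_linf_wrn70-16_cutmix_ddpm.npy')
```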

### Installing

The following has been tested using Python 3.9.2.
Running `run.sh` will create and activate a virtualenv, install all necessary
dependencies, and run a test program to ensure that you can import all the
modules.

```
# Run from the parent directory.
sh adversarial_robustness/run.sh
```

To run the provided code, use this virtualenv:

```
source /tmp/adversarial_robustness_venv/bin/activate
```

You may want to edit `requirements.txt` before running `run.sh` if GPU support
is needed (e.g., use `jaxlib==0.1.67+cuda111`). See JAX's installation
[instructions](https://github.com/google/jax#installation) for more details.
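
As a rough sketch only (the exact wheel tag depends on your CUDA version, and
the current wheel index is given in JAX's installation instructions linked
above), the GPU-enabled pin in `requirements.txt` might look like:

```
# Hypothetical requirements.txt edit for CUDA 11.1; adjust the wheel tag to
# match your CUDA version.
-f https://storage.googleapis.com/jax-releases/jax_releases.html
jaxlib==0.1.67+cuda111
```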

### Using pre-trained models

Once downloaded, a model can be evaluated by running the `eval.py` script in
either the `jax` or `pytorch` folder, e.g.:

```
cd jax
python3 eval.py \
--ckpt=${PATH_TO_CHECKPOINT} --depth=70 --width=16 --dataset=cifar10
```

These models are also directly available within
[RobustBench](https://github.com/RobustBench/robustbench#model-zoo-quick-tour)'s
model zoo.
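
Before running `eval.py`, you can sanity-check a downloaded checkpoint. The
sketch below is not part of the repository and assumes the JAX `.npy` file
stores a pickled parameter dictionary and the `.pt` file a PyTorch state dict:

```
# Sanity-check a downloaded checkpoint (assumptions: the .npy holds a pickled
# dict of parameters, the .pt holds a state dict of tensors).
import numpy as np
import torch

jax_params = np.load('cifar100_linf_wrn70-16_cutmix_ddpm.npy',
                     allow_pickle=True).item()
print(sorted(jax_params)[:5])  # first few parameter (module) names

pt_params = torch.load('cifar100_linf_wrn70-16_cutmix_ddpm.pt',
                       map_location='cpu')
print(list(pt_params)[:5])     # first few state-dict keys
```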

### Training your own model

We also provide a training pipeline that reproduces results from both
publications. This pipeline uses [Jaxline](https://github.com/deepmind/jaxline)
and is written using [JAX](https://github.com/google/jax) and
[Haiku](https://github.com/deepmind/dm-haiku). To train a model, modify the
configuration in the `get_config()` function of `jax/experiment.py` and issue
the following command from within the virtualenv created above:

```
cd jax
python3 train.py --config=experiment.py
```

The training pipeline can run on multiple worker machines and multiple devices
(either GPUs or TPUs). See [Jaxline](https://github.com/deepmind/jaxline) for
more details.
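
As a quick, optional check (not part of the pipeline itself), you can confirm
that JAX sees the accelerators you expect before launching a long run:

```
# List the devices JAX can see; on a GPU/TPU machine this should show more
# than just CPU devices.
import jax

print(jax.device_count())
print(jax.local_devices())
```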

We do not provide a PyTorch implementation of our training pipeline. However,
you may find one on GitHub, e.g.,
[adversarial_robustness_pytorch](https://github.com/imrahulr/adversarial_robustness_pytorch)
(by Rahul Rade).

## Datasets

### Extracted dataset

Gowal et al. (2020) use samples extracted from
[TinyImages-80M](https://groups.csail.mit.edu/vision/TinyImages/).
Unfortunately, since then, the official TinyImages-80M dataset has been
withdrawn (due to the presence of offensive images). As such, we cannot provide
a download link to our extracted data until we have manually verified that all
extracted images are not offensive. If you want to reproduce our setup, consider
the generated datasets below. We are also happy to help, so feel free to reach
out to Sven Gowal directly.

### Generated datasets

Rebuffi et al. (2021) use samples generated by a Denoising Diffusion
Probabilistic Model [(DDPM; Ho et al., 2020)](https://arxiv.org/abs/2006.11239).
The generated samples are provided as `.npz` archives; labels can be read with
`labels = npzfile['label']`.
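
As a rough illustration (not taken verbatim from the README), a minimal sketch
of reading one of the generated archives with NumPy; the filename and the
`image` key are assumptions, while the `label` key matches the access shown
above:

```
# Load a generated dataset (hypothetical filename); 'image' is an assumed key,
# 'label' matches the access quoted above.
import numpy as np

npzfile = np.load('cifar10_ddpm.npz')
images = npzfile['image']
labels = npzfile['label']
print(images.shape, labels.shape)
```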

## Citing this work

If you use this code (or any derived code), data or these models in your work,
please cite the relevant accompanying paper:

```
@article{gowal2020uncovering,
  title={Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples},
  author={Gowal, Sven and Qin, Chongli and Uesato, Jonathan and Mann, Timothy and Kohli, Pushmeet},
  journal={arXiv preprint arXiv:2010.03593},
  year={2020}
}
```

and/or

```
@article{rebuffi2021fixing,
  title={Fixing Data Augmentation to Improve Adversarial Robustness},
  author={Rebuffi, Sylvestre-Alvise and Gowal, Sven and Calian, Dan A. and Stimberg, Florian and Wiles, Olivia and Mann, Timothy},
  journal={arXiv preprint arXiv:2103.01946},
  year={2021}
}
```
