This is the official PyTorch implementation of the work: Puranik, B., Beirami, A., Qin, Y. and Madhow, U., "Improving Robustness via Tilted Exponential Layer: A Communication-Theoretic Perspective", in International Conference on Artificial Intelligence and Statistics (AISTATS 2024), pp. 4510-4518, PMLR.
State-of-the-art techniques for enhancing robustness of deep networks mostly rely on empirical risk minimization with suitable data augmentation. In this work, we propose a complementary approach motivated by communication theory, aimed at enhancing the signal-to-noise ratio at the output of a neural network layer via neural competition during learning and inference. In addition to standard empirical risk minimization, neurons compete to sparsely represent layer inputs by maximization of a tilted exponential (TEXP) objective function for the layer. TEXP learning can be interpreted as maximum likelihood estimation of matched filters under a Gaussian model for data noise. Inference in a TEXP layer is accomplished by replacing batch norm by a tilted softmax, which can be interpreted as computation of posterior probabilities for the competing signaling hypotheses represented by each neuron. After providing insights via simplified models, we show, by experimentation on standard image datasets, that TEXP learning and inference enhances robustness against noise and other common corruptions, without requiring data augmentation. Further cumulative gains in robustness against this array of distortions can be obtained by appropriately combining TEXP with data augmentation techniques.
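As an illustration of the tilted softmax mentioned above, the following is a minimal sketch in plain Python, not the implementation used in this repository. It assumes one activation per competing neuron and a positive tilt parameter that scales the activations before normalization. The names `tilted_softmax` and `t` are hypothetical and chosen only for this example.

```python
import math


def tilted_softmax(activations, t=1.0):
    """Return exp(t * a_k) / sum_j exp(t * a_j) for each activation a_k.

    Assumption: `activations` holds one value per competing neuron and
    `t` > 0 is the tilt parameter. Both names are illustrative only.
    """
    # Subtract the maximum before exponentiating for numerical stability
    # (for t > 0); this does not change the normalized result.
    peak = max(activations)
    weights = [math.exp(t * (a - peak)) for a in activations]
    total = sum(weights)
    return [w / total for w in weights]


if __name__ == "__main__":
    outputs = [0.2, 0.9, 0.1]  # made-up activations for three neurons
    for tilt in (1.0, 10.0):
        probs = tilted_softmax(outputs, t=tilt)
        print(tilt, [round(p, 3) for p in probs])
```

In this sketch, a larger tilt concentrates more of the probability mass on the neuron with the largest activation, which is one way to picture the competition between neurons.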
The code to create simplified data models and apply unsupervised TEXP learning to a single-layer neural network is in the Jupyter notebook simplified_models.ipynb.
Dependencies are listed in requirements.txt. Packages such as robustbench and autoattack should be installed by cloning the original repositories at https://github.com/RobustBench/robustbench and https://github.com/fra31/auto-attack; both are needed only for evaluation. All results on VGG-based models with the CIFAR-10 dataset were obtained using this codebase.
To install the requirements:
pip install -r requirements.txt
After cloning the repository, enter the project folder and create a .env file. Then add the project directory to the .env file as:
PROJECT_DIR=<project directory>/texp_for_robustness/
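To check that the .env file is set up as described, a small standalone script such as the one below can be run from the project folder. It is only a sketch: it assumes the file contains plain KEY=VALUE lines, and the helper `read_env_file` is a hypothetical name that is not part of this repository, which may load the file differently.

```python
import os
from pathlib import Path


def read_env_file(path=".env"):
    """Parse simple KEY=VALUE lines; blank lines and # comments are skipped."""
    values = {}
    for raw in Path(path).read_text().splitlines():
        line = raw.strip()
        if not line or line.startswith("#") or "=" not in line:
            continue
        key, _, value = line.partition("=")
        # Drop surrounding whitespace and optional quotes around the value.
        values[key.strip()] = value.strip().strip('"').strip("'")
    return values


if __name__ == "__main__":
    env_path = Path(".env")
    if env_path.exists():
        project_dir = read_env_file(env_path).get("PROJECT_DIR")
        print("PROJECT_DIR =", project_dir)
        if project_dir and not os.path.isdir(project_dir):
            print("Warning: PROJECT_DIR does not point to an existing directory")
    else:
        print("No .env file found in", Path.cwd())
```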
All the hyperparameters and other settings are located in the file src/configs/cifar.yaml. Some of the parameters are exposed in the shell scripts used to train and evaluate models, where settings from the config file can be overridden.
The default parameters and settings for training a TEXP-VGG-16 model are set in src/sh/train.sh. To launch training, run the following from the project directory. Training also saves the model checkpoint.
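The precedence described above (values passed from a script replace the defaults in the config file) can be pictured with the following sketch. It does not reproduce the mechanism this repository uses. The function `apply_overrides`, the dotted-key convention, and the example keys are all assumptions made for illustration; the real keys are in src/configs/cifar.yaml.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with dotted-key overrides applied."""
    merged = copy.deepcopy(config)
    for dotted_key, value in overrides.items():
        node = merged
        *parents, leaf = dotted_key.split(".")
        # Walk (or create) the nested dictionaries leading to the leaf key.
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return merged


if __name__ == "__main__":
    # Hypothetical defaults standing in for values read from the config file.
    defaults = {"train": {"epochs": 100, "lr": 0.1}, "seed": 0}
    # Hypothetical override, as might be passed from a shell script.
    print(apply_overrides(defaults, {"train.lr": 0.01}))
```

The defaults are copied before merging, so the original dictionary stays unchanged and can be reused across runs.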
bash src/sh/train.sh
Evaluate the model using the trained checkpoint. The parameters in eval.sh are set to match the default parameters in train.sh.
bash src/sh/eval.sh
The common-corruptions evaluation function test_common_corruptions() relies on the loader function load_cifar10c() in robustbench's data.py, which has been modified to return corruptions of all severity levels when needed. The modified file is located at robustbench_helper/data.py. After installing robustbench, replace the original file at "site-packages/robustbench/data.py" with this version and modify the enums appropriately.
Tests on CIFAR-100 with WRN-28-10 and on ImageNet with ResNet-50 were performed by cloning publicly available repositories for baseline WRN and ResNet training and adding a single layer of TEXP blocks on top of each.