TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation

This repository contains the official PyTorch implementation of the CVPR 2023 paper TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation by Devavrat Tomar, Guillaume Vray, Behzad Bozorgtabar, and Jean-Philippe Thiran.

Abstract

Most recent test-time adaptation methods focus only on classification tasks, use specialized network architectures, destroy model calibration, or rely on lightweight information from the source domain. To tackle these issues, this paper proposes a novel Test-time Self-Learning method with automatic Adversarial augmentation dubbed TeSLA for adapting a pre-trained source model to the unlabeled streaming test data. In contrast to conventional self-learning methods based on cross-entropy, we introduce a new test-time loss function through an implicitly tight connection with the mutual information and online knowledge distillation. Furthermore, we propose a learnable efficient adversarial augmentation module that further enhances online knowledge distillation by simulating high entropy augmented images. Our method achieves state-of-the-art classification and segmentation results on several benchmarks and types of domain shifts, particularly on challenging measurement shifts of medical images. TeSLA also benefits from several desirable properties compared to competing methods in terms of calibration, uncertainty metrics, insensitivity to model architectures, and source training strategies, all supported by extensive ablations.

Overview of TeSLA Framework

(a) The student model is adapted on the test images by minimizing the proposed test-time objective. The high-quality soft pseudo-labels required by this objective are obtained from the exponentially weighted averaged teacher model and refined with the proposed Soft Pseudo-Label Refinement (PLR) applied to the corresponding test images. The refined soft pseudo-labels are further used for teacher-student knowledge distillation on the adversarially augmented views of the test images. (b) The adversarial augmentations are obtained by applying learned sub-policies, sampled i.i.d. from the set of augmentation operations according to the learned probability distribution, with their corresponding magnitudes selected from the learned magnitude parameters. The probability and magnitude parameters of the augmentation module are updated by the unbiased gradient estimator of the loss computed on the augmented test images.
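
As a rough illustration of this loop, the PyTorch sketch below shows one adaptation step with an EMA teacher and knowledge distillation on augmented views. It is a minimal schematic, not the repository's implementation: the soft pseudo-label refinement and the learnable adversarial augmentation module are abstracted behind the augment callable, and the function names and momentum value are assumptions made for this example.

import torch
import torch.nn.functional as F

def ema_update(teacher, student, momentum=0.999):
    # Teacher = exponentially weighted average of the student's weights.
    with torch.no_grad():
        for t_p, s_p in zip(teacher.parameters(), student.parameters()):
            t_p.mul_(momentum).add_(s_p, alpha=1.0 - momentum)

def adaptation_step(student, teacher, augment, x_test, optimizer):
    # 1) Teacher produces soft pseudo-labels on the clean test batch
    #    (the paper's soft pseudo-label refinement is omitted here).
    with torch.no_grad():
        soft_labels = F.softmax(teacher(x_test), dim=1)
    # 2) Knowledge distillation on adversarially augmented views: the student
    #    is trained to match the teacher's soft pseudo-labels.
    x_aug = augment(x_test)  # learned adversarial augmentation (placeholder callable)
    log_probs = F.log_softmax(student(x_aug), dim=1)
    kd_loss = F.kl_div(log_probs, soft_labels, reduction="batchmean")
    optimizer.zero_grad()
    kd_loss.backward()
    optimizer.step()
    # 3) Teacher follows the student via the EMA update.
    ema_update(teacher, student)
    return kd_loss.item()

# The teacher is typically initialized as a copy of the source model,
# e.g. teacher = copy.deepcopy(student).eval() (requires `import copy`).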

Requirements

First, install Anaconda (Python >= 3.8) using this link. Then create and set up the conda environment by running the following commands:

conda create --name TeSLA python=3.8
conda activate TeSLA
conda install pip
pip install -r requirements.txt

For subsequent sessions, activate the TeSLA environment with:

conda activate TeSLA
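
Optionally, a quick check (not part of the repository's scripts) that PyTorch imports correctly and sees the GPU inside the TeSLA environment:

import torch
print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())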

Dataset Download Links

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| CIFAR-10C | click here | ../Datasets/cifar_dataset/CIFAR-10-C/ |
| CIFAR-100C | click here | ../Datasets/cifar_dataset/CIFAR-100-C/ |
| ImageNet-C | click here | ../Datasets/imagenet_dataset/ |
| VisDA-C | click here | ../Datasets/visda_dataset |
| Kather | click here | ../Datasets/Kather/kather2016 |
| VisDA-S | click here | ../Datasets/visda_segmentation_dataset |
| (MRI) Spinal Cord | click here | ../Datasets/MRI/SpinalCord |
| (MRI) Prostate | click here | ../Datasets/MRI/Prostate |
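
If helpful, the small Python check below (not part of the repository) verifies that each dataset was extracted to the relative path listed above; trim the list to the datasets you actually use.

from pathlib import Path

# Expected extraction paths, copied from the table above.
expected = [
    "../Datasets/cifar_dataset/CIFAR-10-C/",
    "../Datasets/cifar_dataset/CIFAR-100-C/",
    "../Datasets/imagenet_dataset/",
    "../Datasets/visda_dataset",
    "../Datasets/Kather/kather2016",
    "../Datasets/visda_segmentation_dataset",
    "../Datasets/MRI/SpinalCord",
    "../Datasets/MRI/Prostate",
]
for p in expected:
    status = "ok" if Path(p).is_dir() else "MISSING"
    print(f"{status:7s} {p}")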

Pre-trained Source Model Download Links

Classification Task

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| CIFAR-10 | click here | ../Source_classifiers/cifar10 |
| CIFAR-100 | click here | ../Source_classifiers/cifar100 |
| ImageNet | PyTorch Default | |
| VisDA-C | click here | ../Source_classifier/VisDA |
| Kather | click here | ../Source_classifier/Kather |

Segmentation Task

| Dataset Name | Download Link | Extract to Relative Path |
| --- | --- | --- |
| VisDA-S | click here | ../Source_Segmentation/VisDA/ |
| MRI (Spinal Cord and Prostate) | click here | ../Source_Segmentation/MRI/ |
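
For reference, the sketch below shows one way a downloaded source checkpoint could be loaded before adaptation. The file name checkpoint.pth and the ResNet-50 architecture are placeholders, not names from this repository; use the files extracted from the links above and the architecture each source model was trained with.

import torch
from torchvision.models import resnet50

# Placeholder architecture and checkpoint path for illustration only.
model = resnet50(num_classes=10)
state = torch.load("../Source_classifiers/cifar10/checkpoint.pth", map_location="cpu")
# Some checkpoints wrap the weights, e.g. under a "state_dict" key.
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]
model.load_state_dict(state, strict=False)
model.eval()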

Examples of adapting source models using TeSLA

Classification task on CIFAR, ImageNet, VisDA, and Kather datasets for online and offline adaptation:

(1) Common Image Corruptions: CIFAR-10C

bash scripts_classification/online/cifar10.sh
bash scripts_classification/offline/cifar10.sh

(2) Common Image Corruptions: CIFAR-100C

bash scripts_classification/online/cifar100.sh
bash scripts_classification/offline/cifar100.sh

(3) Common Image Corruptions: ImageNet-C

bash scripts_classification/online/imagenet.sh
bash scripts_classification/offline/imagenet.sh

(4) Synthetic to Real Adaptation: VisDA-C

bash scripts_classification/online/visdac.sh
bash scripts_classification/offline/visdac.sh

(5) Medical Measurement Shifts: Kather

bash scripts_classification/online/kather.sh
bash scripts_classification/offline/kather.sh

Segmentation task on VisDA-S and MRI datasets for online and offline adaptation:

(1) Synthetic to Real Adaptation: VisDA-S (GTA5 to Cityscapes)

bash scripts_segmentation/online/cityscapes.sh
bash scripts_segmentation/offline/cityscapes.sh

(2) Medical Domain Shifts: MRI (Spinal Cord and Prostate)

bash scripts_segmentation/online/spinalcord.sh
bash scripts_segmentation/offline/prostate.sh

Citation

If you find our work useful, please consider citing:

@inproceedings{tomar2023TeSLA,
  title={TeSLA: Test-Time Self-Learning With Automatic Adversarial Augmentation},
  author={Tomar, Devavrat and Vray, Guillaume and Bozorgtabar, Behzad and Thiran, Jean-Philippe},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2023}
}