This is the official implementation of the following paper:
Wenxuan Bao, Tianxin Wei, Haohan Wang, Jingrui He. Adaptive Test-Time Personalization for Federated Learning. NeurIPS 2023.
- We consider a novel setting named Test-Time Personalized Federated Learning, which addresses the challenge of personalizing a global model to each unparticipating client at test time, without requiring any labeled data.
- We propose ATP, which adaptively learns the adaptation rate for each module, enabling it to handle different types of distribution shifts among FL clients.
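For intuition, here is a minimal PyTorch sketch of the core idea under our own simplifying assumptions (entropy minimization as the label-free objective, plain gradient steps scaled per module); names like `adapt_one_step` and `rates` are illustrative, not the repo's API:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sketch only: the objective and all names are assumptions
# for exposition, not the repository's exact implementation.

def entropy_loss(logits):
    # Average prediction entropy, a label-free test-time objective.
    log_probs = F.log_softmax(logits, dim=1)
    return -(log_probs.exp() * log_probs).sum(dim=1).mean()

def adapt_one_step(model, x, rates):
    """One unsupervised adaptation step; each parameter's update is
    scaled by a per-module adaptation rate looked up in `rates`."""
    loss = entropy_loss(model(x))
    params = [p for _, p in model.named_parameters()]
    grads = torch.autograd.grad(loss, params)
    with torch.no_grad():
        for (name, p), g in zip(model.named_parameters(), grads):
            module_name = name.rsplit(".", 1)[0]  # e.g. "layer1.0.bn1"
            p -= rates[module_name] * g

# Tiny usage example with uniform initial rates (ATP learns these, so
# that modules sensitive to the client's shift adapt more).
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
rates = {n.rsplit(".", 1)[0]: 1e-3 for n, _ in model.named_parameters()}
adapt_one_step(model, torch.randn(8, 3, 32, 32), rates)
```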
The code requires the following packages:

- python 3.8.5
- cudatoolkit 10.2.89
- cudnn 7.6.5
- pytorch 1.11.0
- torchvision 0.12.0
- numpy 1.18.5
- tqdm 4.65.0
- matplotlib 3.7.1
If you prefer to generate the CIFAR-10C and CIFAR-100C datasets yourself, these packages may also be required:
- wandb 0.16.0
- scikit-image 0.17.2
- opencv-python 4.8.0.74
(The codebase should not be very sensitive to the exact package versions.)
We consider three types of distribution shifts in our CIFAR-10C experiments: feature shift, label shift, and hybrid shift.
First, enter the experiment directory for the corresponding shift:

```bash
cd ./exp/cifar10/${shift}
```

where `${shift}` should be replaced by `feat` (feature shift), `label` (label shift), or `hybrid` (hybrid shift).
Then, prepare the data:

```bash
./data_prepare.sh
```
This shell script partitions the CIFAR-10 dataset into 300 clients (240 source clients and 60 target clients) and saves the partition indices to `~/data/atp/partition/cifar10/`. When there are corruptions (feature shift and hybrid shift), it also caches the corrupted dataset to `~/data/atp/cifar10` to save time.
We also upload these
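The exact partition scheme is encapsulated in the script above; purely as an illustration (and not necessarily the script's scheme), label shift across clients is commonly simulated with a per-class Dirichlet split, e.g.:

```python
import numpy as np

# Illustration only: a common Dirichlet-based label-shift simulation,
# not necessarily the exact scheme used by data_prepare.sh.
def dirichlet_partition(labels, num_clients, alpha=0.5, seed=0):
    """Split sample indices across clients with per-class Dirichlet
    proportions; smaller alpha gives more skewed label distributions."""
    rng = np.random.default_rng(seed)
    client_indices = [[] for _ in range(num_clients)]
    for c in range(labels.max() + 1):
        idx = np.where(labels == c)[0]
        rng.shuffle(idx)
        props = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(props)[:-1] * len(idx)).astype(int)
        for client, part in enumerate(np.split(idx, cuts)):
            client_indices[client].extend(part.tolist())
    return client_indices

# Example: 50,000 CIFAR-10-sized labels split across 300 clients.
labels = np.random.randint(0, 10, size=50000)
parts = dirichlet_partition(labels, num_clients=300)
```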
Before running ATP, we need to train a global model on the source clients' training sets. We use the FedAvg algorithm to train this global model:
```bash
./pretrain_fedavg_${model}.sh
```

Here `${model}` specifies the model architecture. We used `resnet18` (ResNet-18) and `cnn` (shallow CNN) in our paper.
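FedAvg's server step is simple: each round, participating clients train locally from the current global weights, and the server averages the returned weights, typically weighted by local dataset size. A minimal sketch with illustrative names (not the repo's exact code):

```python
import copy
import torch

# Illustrative FedAvg aggregation: average client state_dicts,
# weighted by local dataset size (not the repo's exact code).
def fedavg_aggregate(global_model, client_states, client_sizes):
    total = float(sum(client_sizes))
    avg_state = copy.deepcopy(client_states[0])
    for key in avg_state:
        # .float() also handles integer buffers such as BatchNorm counters.
        avg_state[key] = sum(
            state[key].float() * (n / total)
            for state, n in zip(client_states, client_sizes)
        )
    global_model.load_state_dict(avg_state)
```

Each round, clients would copy the global model, run a few local epochs on their own training sets, and send back `model.state_dict()`; the server then calls `fedavg_aggregate` and broadcasts the result.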
Next, train the adaptation rates with ATP:

```bash
./atp_train_${model}.sh
```

This script also prints the evaluation result of ATP-batch in each iteration.
Finally, evaluate ATP on the target clients:

```bash
./atp_test_${model}.sh
```

The results (accuracy, %) should be similar to the tables below. Note that these are results with a single random seed, while we reported results averaged over five different random seeds in our paper.
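For context on the two modes reported below (our reading of the paper, with illustrative names reusing `adapt_one_step` from the sketch above): ATP-batch adapts a fresh copy of the global model to every test batch independently, while ATP-online keeps adapting one model across the whole test stream, so later batches benefit from earlier adaptation:

```python
import copy
import torch

# Illustrative contrast between the two test-time modes; reuses the
# adapt_one_step() sketch from above (names are assumptions).
def atp_batch(global_model, loader, rates):
    preds = []
    for x, _ in loader:
        model = copy.deepcopy(global_model)  # reset before every batch
        adapt_one_step(model, x, rates)
        with torch.no_grad():
            preds.append(model(x).argmax(dim=1))
    return torch.cat(preds)

def atp_online(global_model, loader, rates):
    model = copy.deepcopy(global_model)      # one copy, adapted throughout
    preds = []
    for x, _ in loader:
        adapt_one_step(model, x, rates)
        with torch.no_grad():
            preds.append(model(x).argmax(dim=1))
    return torch.cat(preds)
```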
**ResNet-18**

| Algorithm | Feature shift | Label shift | Hybrid shift |
|---|---|---|---|
| No adaptation | 69.62 | 72.58 | 63.55 |
| ATP-batch | 73.48 | 80.09 | 72.85 |
| ATP-online | 73.83 | 81.72 | 75.34 |
**Shallow CNN**

| Algorithm | Feature shift | Label shift | Hybrid shift |
|---|---|---|---|
| No adaptation | 64.36 | 69.15 | 61.87 |
| ATP-batch | 67.02 | 76.14 | 68.48 |
| ATP-online | 67.22 | 78.38 | 70.86 |
Coming soon.
If you are also interested in test-time personalization, please consider giving a star ⭐️ to our repo and citing our paper:
```bibtex
@inproceedings{bao2023adaptive,
  title={Adaptive Test-Time Personalization for Federated Learning},
  author={Wenxuan Bao and Tianxin Wei and Haohan Wang and Jingrui He},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=rbw9xCU6Ci}
}
```